Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7cf6242

Browse files
authored
test(cli): fix TestSSH/RemoteForward_Unix_Signal flake (#16172)
1 parent ea8cd55 commit 7cf6242

File tree

1 file changed

+90
-87
lines changed

1 file changed

+90
-87
lines changed

cli/ssh_test.go

+90-87
Original file line numberDiff line numberDiff line change
@@ -819,102 +819,105 @@ func TestSSH(t *testing.T) {
819819

820820
tmpdir := tempDirUnixSocket(t)
821821
localSock := filepath.Join(tmpdir, "local.sock")
822-
l, err := net.Listen("unix", localSock)
823-
require.NoError(t, err)
824-
defer l.Close()
825822
remoteSock := path.Join(tmpdir, "remote.sock")
826823
for i := 0; i < 2; i++ {
827-
t.Logf("connect %d of 2", i+1)
828-
inv, root := clitest.New(t,
829-
"ssh",
830-
workspace.Name,
831-
"--remote-forward",
832-
remoteSock+":"+localSock,
833-
)
834-
fsn := clitest.NewFakeSignalNotifier(t)
835-
inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext)
836-
inv.Stdout = io.Discard
837-
inv.Stderr = io.Discard
838-
839-
clitest.SetupConfig(t, client, root)
840-
cmdDone := tGo(t, func() {
841-
err := inv.WithContext(ctx).Run()
842-
assert.Error(t, err)
843-
})
824+
func() { // Function scope for defer.
825+
t.Logf("Connect %d/2", i+1)
826+
827+
inv, root := clitest.New(t,
828+
"ssh",
829+
workspace.Name,
830+
"--remote-forward",
831+
remoteSock+":"+localSock,
832+
)
833+
fsn := clitest.NewFakeSignalNotifier(t)
834+
inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext)
835+
inv.Stdout = io.Discard
836+
inv.Stderr = io.Discard
844837

845-
// accept a single connection
846-
msgs := make(chan string, 1)
847-
go func() {
848-
conn, err := l.Accept()
849-
if !assert.NoError(t, err) {
850-
return
851-
}
852-
msg, err := io.ReadAll(conn)
853-
if !assert.NoError(t, err) {
854-
return
855-
}
856-
msgs <- string(msg)
857-
}()
838+
clitest.SetupConfig(t, client, root)
839+
cmdDone := tGo(t, func() {
840+
err := inv.WithContext(ctx).Run()
841+
assert.Error(t, err)
842+
})
858843

859-
// Unfortunately, there is a race in crypto/ssh where it sends the request to forward
860-
// unix sockets before it is prepared to receive the response, meaning that even after
861-
// the socket exists on the file system, the client might not be ready to accept the
862-
// channel.
863-
//
864-
// https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go;drc=2fc4c88bf43f0ea5ea305eae2b7af24b2cc93287;l=33
865-
//
866-
// To work around this, we attempt to send messages in a loop until one succeeds
867-
success := make(chan struct{})
868-
done := make(chan struct{})
869-
go func() {
870-
defer close(done)
871-
var (
872-
conn net.Conn
873-
err error
874-
)
875-
for {
876-
time.Sleep(testutil.IntervalMedium)
877-
select {
878-
case <-ctx.Done():
879-
t.Error("timeout")
880-
return
881-
case <-success:
844+
// accept a single connection
845+
msgs := make(chan string, 1)
846+
l, err := net.Listen("unix", localSock)
847+
require.NoError(t, err)
848+
defer l.Close()
849+
go func() {
850+
conn, err := l.Accept()
851+
if !assert.NoError(t, err) {
882852
return
883-
default:
884-
// Ok
885853
}
886-
conn, err = net.Dial("unix", remoteSock)
887-
if err != nil {
888-
t.Logf("dial error: %s", err)
889-
continue
890-
}
891-
_, err = conn.Write([]byte("test"))
892-
if err != nil {
893-
t.Logf("write error: %s", err)
854+
msg, err := io.ReadAll(conn)
855+
if !assert.NoError(t, err) {
856+
return
894857
}
895-
err = conn.Close()
896-
if err != nil {
897-
t.Logf("close error: %s", err)
858+
msgs <- string(msg)
859+
}()
860+
861+
// Unfortunately, there is a race in crypto/ssh where it sends the request to forward
862+
// unix sockets before it is prepared to receive the response, meaning that even after
863+
// the socket exists on the file system, the client might not be ready to accept the
864+
// channel.
865+
//
866+
// https://cs.opensource.google/go/x/crypto/+/master:ssh/streamlocal.go;drc=2fc4c88bf43f0ea5ea305eae2b7af24b2cc93287;l=33
867+
//
868+
// To work around this, we attempt to send messages in a loop until one succeeds
869+
success := make(chan struct{})
870+
done := make(chan struct{})
871+
go func() {
872+
defer close(done)
873+
var (
874+
conn net.Conn
875+
err error
876+
)
877+
for {
878+
time.Sleep(testutil.IntervalMedium)
879+
select {
880+
case <-ctx.Done():
881+
t.Error("timeout")
882+
return
883+
case <-success:
884+
return
885+
default:
886+
// Ok
887+
}
888+
conn, err = net.Dial("unix", remoteSock)
889+
if err != nil {
890+
t.Logf("dial error: %s", err)
891+
continue
892+
}
893+
_, err = conn.Write([]byte("test"))
894+
if err != nil {
895+
t.Logf("write error: %s", err)
896+
}
897+
err = conn.Close()
898+
if err != nil {
899+
t.Logf("close error: %s", err)
900+
}
898901
}
899-
}
900-
}()
902+
}()
901903

902-
msg := testutil.RequireRecvCtx(ctx, t, msgs)
903-
require.Equal(t, "test", msg)
904-
close(success)
905-
fsn.Notify()
906-
<-cmdDone
907-
fsn.AssertStopped()
908-
// wait for dial goroutine to complete
909-
_ = testutil.RequireRecvCtx(ctx, t, done)
910-
911-
// wait for the remote socket to get cleaned up before retrying,
912-
// because cleaning up the socket happens asynchronously, and we
913-
// might connect to an old listener on the agent side.
914-
require.Eventually(t, func() bool {
915-
_, err = os.Stat(remoteSock)
916-
return xerrors.Is(err, os.ErrNotExist)
917-
}, testutil.WaitShort, testutil.IntervalFast)
904+
msg := testutil.RequireRecvCtx(ctx, t, msgs)
905+
require.Equal(t, "test", msg)
906+
close(success)
907+
fsn.Notify()
908+
<-cmdDone
909+
fsn.AssertStopped()
910+
// wait for dial goroutine to complete
911+
_ = testutil.RequireRecvCtx(ctx, t, done)
912+
913+
// wait for the remote socket to get cleaned up before retrying,
914+
// because cleaning up the socket happens asynchronously, and we
915+
// might connect to an old listener on the agent side.
916+
require.Eventually(t, func() bool {
917+
_, err = os.Stat(remoteSock)
918+
return xerrors.Is(err, os.ErrNotExist)
919+
}, testutil.WaitShort, testutil.IntervalFast)
920+
}()
918921
}
919922
})
920923

0 commit comments

Comments
 (0)