Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0c755fb

Browse files
committed
Fix pty tests on Windows
1 parent 2606fda commit 0c755fb

8 files changed

+58
-23
lines changed

agent/agent.go

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"io"
1010
"net"
1111
"os/exec"
12+
"os/user"
13+
"sync"
1214
"time"
1315

1416
"cdr.dev/slog"
@@ -39,7 +41,6 @@ func DialSSHClient(conn *peer.Conn) (*gossh.Client, error) {
3941
return nil, err
4042
}
4143
sshConn, channels, requests, err := gossh.NewClientConn(netConn, "localhost:22", &gossh.ClientConfig{
42-
User: "kyle",
4344
Config: gossh.Config{
4445
Ciphers: []string{"arcfour"},
4546
},
@@ -66,6 +67,7 @@ func New(dialer Dialer, options *Options) io.Closer {
6667
clientDialer: dialer,
6768
options: options,
6869
closeCancel: cancelFunc,
70+
closed: make(chan struct{}),
6971
}
7072
server.init(ctx)
7173
return server
@@ -76,6 +78,7 @@ type server struct {
7678
options *Options
7779

7880
closeCancel context.CancelFunc
81+
closeMutex sync.Mutex
7982
closed chan struct{}
8083

8184
sshServer *ssh.Server
@@ -153,10 +156,19 @@ func (*server) handleSSHSession(session ssh.Session) error {
153156
err error
154157
)
155158

159+
username := session.User()
160+
if username == "" {
161+
currentUser, err := user.Current()
162+
if err != nil {
163+
return xerrors.Errorf("get current user: %w", err)
164+
}
165+
username = currentUser.Username
166+
}
167+
156168
// gliderlabs/ssh returns a command slice of zero
157169
// when a shell is requested.
158170
if len(session.Command()) == 0 {
159-
command, err = usershell.Get(session.User())
171+
command, err = usershell.Get(username)
160172
if err != nil {
161173
return xerrors.Errorf("get user shell: %w", err)
162174
}
@@ -208,6 +220,7 @@ func (*server) handleSSHSession(session ssh.Session) error {
208220
_, _ = io.Copy(session, ptty.Output())
209221
}()
210222
_, _ = process.Wait()
223+
_ = ptty.Close()
211224
return nil
212225
}
213226

@@ -254,7 +267,11 @@ func (s *server) run(ctx context.Context) {
254267
for {
255268
conn, err := peerListener.Accept()
256269
if err != nil {
257-
// This is closed!
270+
if s.isClosed() {
271+
return
272+
}
273+
s.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
274+
s.run(ctx)
258275
return
259276
}
260277
go s.handlePeerConn(ctx, conn)
@@ -265,15 +282,21 @@ func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) {
265282
for {
266283
channel, err := conn.Accept(ctx)
267284
if err != nil {
268-
// TODO: Log here!
285+
if s.isClosed() {
286+
return
287+
}
288+
s.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err))
269289
return
270290
}
271291

272292
switch channel.Protocol() {
273293
case "ssh":
274294
s.sshServer.HandleConn(channel.NetConn())
275-
case "proxy":
276-
// Proxy the port provided.
295+
default:
296+
s.options.Logger.Warn(ctx, "unhandled protocol from channel",
297+
slog.F("protocol", channel.Protocol()),
298+
slog.F("label", channel.Label()),
299+
)
277300
}
278301
}
279302
}
@@ -289,6 +312,13 @@ func (s *server) isClosed() bool {
289312
}
290313

291314
func (s *server) Close() error {
292-
s.sshServer.Close()
315+
s.closeMutex.Lock()
316+
defer s.closeMutex.Unlock()
317+
if s.isClosed() {
318+
return nil
319+
}
320+
close(s.closed)
321+
s.closeCancel()
322+
_ = s.sshServer.Close()
293323
return nil
294324
}

agent/agent_test.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"go.uber.org/goleak"
1212
"golang.org/x/crypto/ssh"
1313

14+
"cdr.dev/slog"
1415
"cdr.dev/slog/sloggers/slogtest"
1516
"github.com/coder/coder/agent"
1617
"github.com/coder/coder/peer"
@@ -35,7 +36,9 @@ func TestAgent(t *testing.T) {
3536
Logger: slogtest.Make(t, nil),
3637
})
3738
require.NoError(t, err)
38-
defer conn.Close()
39+
t.Cleanup(func() {
40+
_ = conn.Close()
41+
})
3942
sshClient, err := agent.DialSSHClient(conn)
4043
require.NoError(t, err)
4144
session, err := sshClient.NewSession()
@@ -58,7 +61,9 @@ func TestAgent(t *testing.T) {
5861
Logger: slogtest.Make(t, nil),
5962
})
6063
require.NoError(t, err)
61-
defer conn.Close()
64+
t.Cleanup(func() {
65+
_ = conn.Close()
66+
})
6267
sshClient, err := agent.DialSSHClient(conn)
6368
require.NoError(t, err)
6469
session, err := sshClient.NewSession()
@@ -94,7 +99,7 @@ func setup(t *testing.T) proto.DRPCPeerBrokerClient {
9499
Logger: slogtest.Make(t, nil),
95100
})
96101
}, &agent.Options{
97-
Logger: slogtest.Make(t, nil),
102+
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
98103
})
99104
t.Cleanup(func() {
100105
_ = client.Close()

agent/usershell/usershell_darwin.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package usershell
2+
3+
// Get returns the $SHELL environment variable.
4+
// TODO: This should use "dscl" to fetch the proper value. See:
5+
// https://stackoverflow.com/questions/16375519/how-to-get-the-default-shell
6+
func Get(username string) (string, error) {
7+
return os.Getenv("SHELL"), nil
8+
}

agent/usershell/usershell_other.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
//go:build !windows
2-
// +build !windows
1+
//go:build !windows && !darwin
2+
// +build !windows,!darwin
33

44
package usershell
55

agent/usershell/usershell_other_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
//go:build !windows
2-
// +build !windows
1+
//go:build !windows && !darwin
2+
// +build !windows,!darwin
33

44
package usershell_test
55

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ require (
4141
github.com/pion/logging v0.2.2
4242
github.com/pion/transport v0.13.0
4343
github.com/pion/webrtc/v3 v3.1.23
44+
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8
4445
github.com/psanford/memfs v0.0.0-20210214183328-a001468d78ef
4546
github.com/quasilyte/go-ruleguard/dsl v0.3.17
4647
github.com/spf13/cobra v1.3.0
@@ -115,7 +116,6 @@ require (
115116
github.com/pion/stun v0.3.5 // indirect
116117
github.com/pion/turn/v2 v2.0.6 // indirect
117118
github.com/pion/udp v0.1.1 // indirect
118-
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
119119
github.com/pkg/errors v0.9.1 // indirect
120120
github.com/pmezard/go-difflib v1.0.0 // indirect
121121
github.com/sirupsen/logrus v1.8.1 // indirect

pty/start_other.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,5 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
3232
pty: ptty,
3333
tty: tty,
3434
}
35-
go func() {
36-
_ = cmd.Wait()
37-
_ = oPty.Close()
38-
}()
3935
return oPty, cmd.Process, nil
4036
}

pty/start_windows.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,6 @@ func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) {
8383
if err != nil {
8484
return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err)
8585
}
86-
go func() {
87-
_, _ = process.Wait()
88-
_ = pty.Close()
89-
}()
9086
return pty, process, nil
9187
}
9288

0 commit comments

Comments
 (0)