99 "io"
1010 "net"
1111 "os/exec"
12+ "os/user"
13+ "sync"
1214 "time"
1315
1416 "cdr.dev/slog"
@@ -39,7 +41,6 @@ func DialSSHClient(conn *peer.Conn) (*gossh.Client, error) {
3941 return nil , err
4042 }
4143 sshConn , channels , requests , err := gossh .NewClientConn (netConn , "localhost:22" , & gossh.ClientConfig {
42- User : "kyle" ,
4344 Config : gossh.Config {
4445 Ciphers : []string {"arcfour" },
4546 },
@@ -66,6 +67,7 @@ func New(dialer Dialer, options *Options) io.Closer {
6667 clientDialer : dialer ,
6768 options : options ,
6869 closeCancel : cancelFunc ,
70+ closed : make (chan struct {}),
6971 }
7072 server .init (ctx )
7173 return server
@@ -76,6 +78,7 @@ type server struct {
7678 options * Options
7779
7880 closeCancel context.CancelFunc
81+ closeMutex sync.Mutex
7982 closed chan struct {}
8083
8184 sshServer * ssh.Server
@@ -153,10 +156,19 @@ func (*server) handleSSHSession(session ssh.Session) error {
153156 err error
154157 )
155158
159+ username := session .User ()
160+ if username == "" {
161+ currentUser , err := user .Current ()
162+ if err != nil {
163+ return xerrors .Errorf ("get current user: %w" , err )
164+ }
165+ username = currentUser .Username
166+ }
167+
156168 // gliderlabs/ssh returns a command slice of zero
157169 // when a shell is requested.
158170 if len (session .Command ()) == 0 {
159- command , err = usershell .Get (session . User () )
171+ command , err = usershell .Get (username )
160172 if err != nil {
161173 return xerrors .Errorf ("get user shell: %w" , err )
162174 }
@@ -208,6 +220,7 @@ func (*server) handleSSHSession(session ssh.Session) error {
208220 _ , _ = io .Copy (session , ptty .Output ())
209221 }()
210222 _ , _ = process .Wait ()
223+ _ = ptty .Close ()
211224 return nil
212225 }
213226
@@ -254,7 +267,11 @@ func (s *server) run(ctx context.Context) {
254267 for {
255268 conn , err := peerListener .Accept ()
256269 if err != nil {
257- // This is closed!
270+ if s .isClosed () {
271+ return
272+ }
273+ s .options .Logger .Debug (ctx , "peer listener accept exited; restarting connection" , slog .Error (err ))
274+ s .run (ctx )
258275 return
259276 }
260277 go s .handlePeerConn (ctx , conn )
@@ -265,15 +282,21 @@ func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) {
265282 for {
266283 channel , err := conn .Accept (ctx )
267284 if err != nil {
268- // TODO: Log here!
285+ if s .isClosed () {
286+ return
287+ }
288+ s .options .Logger .Debug (ctx , "accept channel from peer connection" , slog .Error (err ))
269289 return
270290 }
271291
272292 switch channel .Protocol () {
273293 case "ssh" :
274294 s .sshServer .HandleConn (channel .NetConn ())
275- case "proxy" :
276- // Proxy the port provided.
295+ default :
296+ s .options .Logger .Warn (ctx , "unhandled protocol from channel" ,
297+ slog .F ("protocol" , channel .Protocol ()),
298+ slog .F ("label" , channel .Label ()),
299+ )
277300 }
278301 }
279302}
@@ -289,6 +312,13 @@ func (s *server) isClosed() bool {
289312}
290313
291314func (s * server ) Close () error {
292- s .sshServer .Close ()
315+ s .closeMutex .Lock ()
316+ defer s .closeMutex .Unlock ()
317+ if s .isClosed () {
318+ return nil
319+ }
320+ close (s .closed )
321+ s .closeCancel ()
322+ _ = s .sshServer .Close ()
293323 return nil
294324}
0 commit comments