Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a90d3d6

Browse files
committed
chore: PR comments
1 parent e203bf0 commit a90d3d6

File tree

2 files changed

+53
-47
lines changed

2 files changed

+53
-47
lines changed

agent/ssh.go

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
4444
conn, ok := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
4545
if !ok {
4646
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
47-
return false, []byte{}
47+
return false, nil
4848
}
4949

5050
switch req.Type {
@@ -53,7 +53,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
5353
err := gossh.Unmarshal(req.Payload, &reqPayload)
5454
if err != nil {
5555
h.log.Warn(ctx, "parse [email protected] request payload from client", slog.Error(err))
56-
return false, []byte{}
56+
return false, nil
5757
}
5858

5959
addr := reqPayload.SocketPath
@@ -64,7 +64,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
6464
h.log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
6565
slog.F("socket_path", addr),
6666
)
67-
return false, []byte{}
67+
return false, nil
6868
}
6969

7070
// Create socket parent dir if not exists.
@@ -76,7 +76,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
7676
slog.F("socket_path", addr),
7777
slog.Error(err),
7878
)
79-
return false, []byte{}
79+
return false, nil
8080
}
8181

8282
ln, err := net.Listen("unix", addr)
@@ -85,19 +85,20 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
8585
slog.F("socket_path", addr),
8686
slog.Error(err),
8787
)
88-
return false, []byte{}
88+
return false, nil
8989
}
90+
91+
// The listener needs to successfully start before it can be added to
92+
// the map, so we don't have to worry about checking for an existing
93+
// listener.
94+
//
95+
// This is also what the upstream TCP version of this code does.
9096
h.Lock()
9197
h.forwards[addr] = ln
9298
h.Unlock()
9399
go func() {
94100
<-ctx.Done()
95-
h.Lock()
96-
ln, ok := h.forwards[addr]
97-
h.Unlock()
98-
if ok {
99-
_ = ln.Close()
100-
}
101+
_ = ln.Close()
101102
}()
102103
go func() {
103104
for {
@@ -109,6 +110,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
109110
slog.Error(err),
110111
)
111112
}
113+
// closed below
112114
break
113115
}
114116
payload := gossh.Marshal(&forwardedStreamLocalPayload{
@@ -129,9 +131,14 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
129131
Bicopy(ctx, ch, c)
130132
}()
131133
}
134+
132135
h.Lock()
133-
delete(h.forwards, addr)
136+
ln2, ok := h.forwards[addr]
137+
if ok && ln2 == ln {
138+
delete(h.forwards, addr)
139+
}
134140
h.Unlock()
141+
_ = ln.Close()
135142
}()
136143

137144
return true, nil

cli/ssh.go

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -221,30 +221,18 @@ func ssh() *cobra.Command {
221221
}
222222
}
223223

224-
// Wait for the context to be canceled, or the SSH session to end.
225-
sshErr := make(chan error)
226-
go func() {
227-
defer close(sshErr)
228-
229-
err = sshSession.Wait()
230-
if err != nil {
231-
// If the connection drops unexpectedly, we get an ExitMissingError but no other
232-
// error details, so try to at least give the user a better message
233-
if errors.Is(err, &gossh.ExitMissingError{}) {
234-
sshErr <- xerrors.New("SSH connection ended unexpectedly")
235-
return
236-
}
237-
sshErr <- err
224+
err = sshSession.Wait()
225+
if err != nil {
226+
// If the connection drops unexpectedly, we get an
227+
// ExitMissingError but no other error details, so try to at
228+
// least give the user a better message
229+
if errors.Is(err, &gossh.ExitMissingError{}) {
230+
return xerrors.New("SSH connection ended unexpectedly")
238231
}
239-
}()
240-
241-
select {
242-
case <-ctx.Done():
243-
_ = sshSession.Close()
244-
return ctx.Err()
245-
case err := <-sshErr:
246232
return err
247233
}
234+
235+
return nil
248236
},
249237
}
250238
cliflag.BoolVarP(cmd.Flags(), &stdio, "stdio", "", "CODER_SSH_STDIO", false, "Specifies whether to emit SSH output over stdin/stdout.")
@@ -456,7 +444,12 @@ func uploadGPGKeys(ctx context.Context, sshClient *gossh.Client) error {
456444
// Check if the agent is running in the workspace already.
457445
// Note: we don't support windows in the workspace for GPG forwarding so
458446
// using shell commands is fine.
459-
agentSocketBytes, err := runRemoteSSH(sshClient, nil, "set -eux; agent_socket=$(gpgconf --list-dir agent-socket); echo $agent_socket; test ! -S $agent_socket")
447+
agentSocketBytes, err := runRemoteSSH(sshClient, nil, `
448+
set -eux
449+
agent_socket=$(gpgconf --list-dir agent-socket)
450+
echo "$agent_socket"
451+
test ! -S "$agent_socket"
452+
`)
460453
agentSocket := strings.TrimSpace(string(agentSocketBytes))
461454
if err != nil {
462455
return xerrors.Errorf("check if agent socket is running (check if %q exists): %w", agentSocket, err)
@@ -540,24 +533,30 @@ func sshForwardRemote(ctx context.Context, stderr io.Writer, sshClient *gossh.Cl
540533
return
541534
}
542535

543-
localConn, err := net.Dial(localAddr.Network(), localAddr.String())
544-
if err != nil {
545-
_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
546-
_ = remoteConn.Close()
547-
continue
548-
}
536+
go func() {
537+
defer func() {
538+
_ = remoteConn.Close()
539+
}()
549540

550-
if c, ok := localAddr.(cookieAddr); ok {
551-
_, err = localConn.Write(c.cookie)
541+
localConn, err := net.Dial(localAddr.Network(), localAddr.String())
552542
if err != nil {
553-
_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
543+
_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
544+
return
545+
}
546+
defer func() {
554547
_ = localConn.Close()
555-
_ = remoteConn.Close()
556-
continue
548+
}()
549+
550+
if c, ok := localAddr.(cookieAddr); ok {
551+
_, err = localConn.Write(c.cookie)
552+
if err != nil {
553+
_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
554+
return
555+
}
557556
}
558-
}
559557

560-
go agent.Bicopy(ctx, localConn, remoteConn)
558+
agent.Bicopy(ctx, localConn, remoteConn)
559+
}()
561560
}
562561
}()
563562

0 commit comments

Comments
 (0)