@@ -21,6 +21,8 @@ import (
21
21
"github.com/coder/coder/pty"
22
22
"github.com/coder/retry"
23
23
24
+ "github.com/pkg/sftp"
25
+
24
26
"github.com/gliderlabs/ssh"
25
27
gossh "golang.org/x/crypto/ssh"
26
28
"golang.org/x/xerrors"
@@ -120,7 +122,7 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
120
122
121
123
switch channel .Protocol () {
122
124
case "ssh" :
123
- a .sshServer .HandleConn (channel .NetConn ())
125
+ go a .sshServer .HandleConn (channel .NetConn ())
124
126
default :
125
127
a .options .Logger .Warn (ctx , "unhandled protocol from channel" ,
126
128
slog .F ("protocol" , channel .Protocol ()),
@@ -145,7 +147,10 @@ func (a *agent) init(ctx context.Context) {
145
147
sshLogger := a .options .Logger .Named ("ssh-server" )
146
148
forwardHandler := & ssh.ForwardedTCPHandler {}
147
149
a .sshServer = & ssh.Server {
148
- ChannelHandlers : ssh .DefaultChannelHandlers ,
150
+ ChannelHandlers : map [string ]ssh.ChannelHandler {
151
+ "direct-tcpip" : ssh .DirectTCPIPHandler ,
152
+ "session" : ssh .DefaultSessionHandler ,
153
+ },
149
154
ConnectionFailedCallback : func (conn net.Conn , err error ) {
150
155
sshLogger .Info (ctx , "ssh connection ended" , slog .Error (err ))
151
156
},
@@ -184,61 +189,54 @@ func (a *agent) init(ctx context.Context) {
184
189
NoClientAuth : true ,
185
190
}
186
191
},
192
+ SubsystemHandlers : map [string ]ssh.SubsystemHandler {
193
+ "sftp" : func (session ssh.Session ) {
194
+ server , err := sftp .NewServer (session )
195
+ if err != nil {
196
+ a .options .Logger .Debug (session .Context (), "initialize sftp server" , slog .Error (err ))
197
+ return
198
+ }
199
+ defer server .Close ()
200
+ err = server .Serve ()
201
+ if errors .Is (err , io .EOF ) {
202
+ return
203
+ }
204
+ a .options .Logger .Debug (session .Context (), "sftp server exited with error" , slog .Error (err ))
205
+ },
206
+ },
187
207
}
188
208
189
209
go a .run (ctx )
190
210
}
191
211
192
212
func (a * agent ) handleSSHSession (session ssh.Session ) error {
193
- var (
194
- command string
195
- args = []string {}
196
- err error
197
- )
198
-
199
213
currentUser , err := user .Current ()
200
214
if err != nil {
201
215
return xerrors .Errorf ("get current user: %w" , err )
202
216
}
203
217
username := currentUser .Username
204
218
219
+ shell , err := usershell .Get (username )
220
+ if err != nil {
221
+ return xerrors .Errorf ("get user shell: %w" , err )
222
+ }
223
+
205
224
// gliderlabs/ssh returns a command slice of zero
206
225
// when a shell is requested.
226
+ command := session .RawCommand ()
207
227
if len (session .Command ()) == 0 {
208
- command , err = usershell .Get (username )
209
- if err != nil {
210
- return xerrors .Errorf ("get user shell: %w" , err )
211
- }
212
- } else {
213
- command = session .Command ()[0 ]
214
- if len (session .Command ()) > 1 {
215
- args = session .Command ()[1 :]
216
- }
228
+ command = shell
217
229
}
218
230
219
- signals := make (chan ssh.Signal )
220
- breaks := make (chan bool )
221
- defer close (signals )
222
- defer close (breaks )
223
- go func () {
224
- for {
225
- select {
226
- case <- session .Context ().Done ():
227
- return
228
- // Ignore signals and breaks for now!
229
- case <- signals :
230
- case <- breaks :
231
- }
232
- }
233
- }()
234
-
235
- cmd := exec .CommandContext (session .Context (), command , args ... )
231
+ // OpenSSH executes all commands with the users current shell.
232
+ // We replicate that behavior for IDE support.
233
+ cmd := exec .CommandContext (session .Context (), shell , "-c" , command )
236
234
cmd .Env = append (os .Environ (), session .Environ ()... )
237
235
executablePath , err := os .Executable ()
238
236
if err != nil {
239
237
return xerrors .Errorf ("getting os executable: %w" , err )
240
238
}
241
- cmd .Env = append (session . Environ () , fmt .Sprintf (`GIT_SSH_COMMAND="%s gitssh --"` , executablePath ))
239
+ cmd .Env = append (cmd . Env , fmt .Sprintf (`GIT_SSH_COMMAND="%s gitssh --"` , executablePath ))
242
240
243
241
sshPty , windowSize , isPty := session .Pty ()
244
242
if isPty {
@@ -267,7 +265,7 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
267
265
}
268
266
269
267
cmd .Stdout = session
270
- cmd .Stderr = session
268
+ cmd .Stderr = session . Stderr ()
271
269
// This blocks forever until stdin is received if we don't
272
270
// use StdinPipe. It's unknown what causes this.
273
271
stdinPipe , err := cmd .StdinPipe ()
@@ -281,8 +279,7 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
281
279
if err != nil {
282
280
return xerrors .Errorf ("start: %w" , err )
283
281
}
284
- _ = cmd .Wait ()
285
- return nil
282
+ return cmd .Wait ()
286
283
}
287
284
288
285
// isClosed returns whether the API is closed or not.
0 commit comments