Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 3e182d2

Browse files
committed
Merge branch 'main' into schedule-errors/presleyp/3097
2 parents 77df9ca + 6230d55 commit 3e182d2

File tree

17 files changed

+294
-134
lines changed

17 files changed

+294
-134
lines changed

agent/agent.go

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ const (
4343
ProtocolReconnectingPTY = "reconnecting-pty"
4444
ProtocolSSH = "ssh"
4545
ProtocolDial = "dial"
46+
47+
// MagicSessionErrorCode indicates that something went wrong with the session, rather than the
48+
// command just returning a nonzero exit code, and is chosen as an arbitrary, high number
49+
// unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
50+
MagicSessionErrorCode = 229
4651
)
4752

4853
type Options struct {
@@ -273,9 +278,17 @@ func (a *agent) init(ctx context.Context) {
273278
},
274279
Handler: func(session ssh.Session) {
275280
err := a.handleSSHSession(session)
281+
var exitError *exec.ExitError
282+
if xerrors.As(err, &exitError) {
283+
a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError))
284+
_ = session.Exit(exitError.ExitCode())
285+
return
286+
}
276287
if err != nil {
277288
a.logger.Warn(ctx, "ssh session failed", slog.Error(err))
278-
_ = session.Exit(1)
289+
// This exit code is designed to be unlikely to be confused for a legit exit code
290+
// from the process.
291+
_ = session.Exit(MagicSessionErrorCode)
279292
return
280293
}
281294
},
@@ -403,7 +416,7 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
403416
return cmd, nil
404417
}
405418

406-
func (a *agent) handleSSHSession(session ssh.Session) error {
419+
func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
407420
cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ())
408421
if err != nil {
409422
return err
@@ -426,14 +439,24 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
426439
if err != nil {
427440
return xerrors.Errorf("start command: %w", err)
428441
}
442+
defer func() {
443+
closeErr := ptty.Close()
444+
if closeErr != nil {
445+
a.logger.Warn(context.Background(), "failed to close tty",
446+
slog.Error(closeErr))
447+
if retErr == nil {
448+
retErr = closeErr
449+
}
450+
}
451+
}()
429452
err = ptty.Resize(uint16(sshPty.Window.Height), uint16(sshPty.Window.Width))
430453
if err != nil {
431454
return xerrors.Errorf("resize ptty: %w", err)
432455
}
433456
go func() {
434457
for win := range windowSize {
435-
err = ptty.Resize(uint16(win.Height), uint16(win.Width))
436-
if err != nil {
458+
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
459+
if resizeErr != nil {
437460
a.logger.Warn(context.Background(), "failed to resize tty", slog.Error(err))
438461
}
439462
}
@@ -444,9 +467,15 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
444467
go func() {
445468
_, _ = io.Copy(session, ptty.Output())
446469
}()
447-
_, _ = process.Wait()
448-
_ = ptty.Close()
449-
return nil
470+
err = process.Wait()
471+
var exitErr *exec.ExitError
472+
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
473+
// and not something to be concerned about. But, if it's something else, we should log it.
474+
if err != nil && !xerrors.As(err, &exitErr) {
475+
a.logger.Warn(context.Background(), "wait error",
476+
slog.Error(err))
477+
}
478+
return err
450479
}
451480

452481
cmd.Stdout = session
@@ -549,7 +578,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
549578
go func() {
550579
// If the process dies randomly, we should
551580
// close the pty.
552-
_, _ = process.Wait()
581+
_ = process.Wait()
553582
rpty.Close()
554583
}()
555584
go func() {

agent/agent_test.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import (
1616
"testing"
1717
"time"
1818

19+
"golang.org/x/xerrors"
20+
1921
scp "github.com/bramvdbogaerde/go-scp"
2022
"github.com/google/uuid"
2123
"github.com/pion/udp"
@@ -69,7 +71,7 @@ func TestAgent(t *testing.T) {
6971
require.True(t, strings.HasSuffix(strings.TrimSpace(string(output)), "gitssh --"))
7072
})
7173

72-
t.Run("SessionTTY", func(t *testing.T) {
74+
t.Run("SessionTTYShell", func(t *testing.T) {
7375
t.Parallel()
7476
if runtime.GOOS == "windows" {
7577
// This might be our implementation, or ConPTY itself.
@@ -103,6 +105,29 @@ func TestAgent(t *testing.T) {
103105
require.NoError(t, err)
104106
})
105107

108+
t.Run("SessionTTYExitCode", func(t *testing.T) {
109+
t.Parallel()
110+
session := setupSSHSession(t, agent.Metadata{})
111+
command := "areallynotrealcommand"
112+
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
113+
require.NoError(t, err)
114+
ptty := ptytest.New(t)
115+
require.NoError(t, err)
116+
session.Stdout = ptty.Output()
117+
session.Stderr = ptty.Output()
118+
session.Stdin = ptty.Input()
119+
err = session.Start(command)
120+
require.NoError(t, err)
121+
err = session.Wait()
122+
exitErr := &ssh.ExitError{}
123+
require.True(t, xerrors.As(err, &exitErr))
124+
if runtime.GOOS == "windows" {
125+
assert.Equal(t, 1, exitErr.ExitStatus())
126+
} else {
127+
assert.Equal(t, 127, exitErr.ExitStatus())
128+
}
129+
})
130+
106131
t.Run("LocalForwarding", func(t *testing.T) {
107132
t.Parallel()
108133
random, err := net.Listen("tcp", "127.0.0.1:0")

coderd/workspaces.go

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
288288
dbTTL, err := validWorkspaceTTLMillis(createWorkspace.TTLMillis, time.Duration(template.MaxTtl))
289289
if err != nil {
290290
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
291-
Message: "Invalid Workspace TTL.",
291+
Message: "Invalid Workspace Time to Shutdown.",
292292
Validations: []codersdk.ValidationError{{Field: "ttl_ms", Detail: err.Error()}},
293293
})
294294
return
@@ -523,8 +523,6 @@ func (api *API) putWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
523523
return
524524
}
525525

526-
var validErrs []codersdk.ValidationError
527-
528526
err := api.Database.InTx(func(s database.Store) error {
529527
template, err := s.GetTemplateByID(r.Context(), workspace.TemplateID)
530528
if err != nil {
@@ -536,29 +534,31 @@ func (api *API) putWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
536534

537535
dbTTL, err := validWorkspaceTTLMillis(req.TTLMillis, time.Duration(template.MaxTtl))
538536
if err != nil {
539-
validErrs = append(validErrs, codersdk.ValidationError{Field: "ttl_ms", Detail: err.Error()})
540-
return err
537+
return codersdk.ValidationError{Field: "ttl_ms", Detail: err.Error()}
541538
}
542539
if err := s.UpdateWorkspaceTTL(r.Context(), database.UpdateWorkspaceTTLParams{
543540
ID: workspace.ID,
544541
Ttl: dbTTL,
545542
}); err != nil {
546-
return xerrors.Errorf("update workspace TTL: %w", err)
543+
return xerrors.Errorf("update workspace time until shutdown: %w", err)
547544
}
548545

549546
return nil
550547
})
551548

552549
if err != nil {
553-
code := http.StatusInternalServerError
554-
if len(validErrs) > 0 {
555-
code = http.StatusBadRequest
550+
resp := codersdk.Response{
551+
Message: "Error updating workspace time until shutdown.",
556552
}
557-
httpapi.Write(rw, code, codersdk.Response{
558-
Message: "Error updating workspace time until shutdown!",
559-
Validations: validErrs,
560-
Detail: err.Error(),
561-
})
553+
var validErr codersdk.ValidationError
554+
if errors.As(err, &validErr) {
555+
resp.Validations = []codersdk.ValidationError{validErr}
556+
httpapi.Write(rw, http.StatusBadRequest, resp)
557+
return
558+
}
559+
560+
resp.Detail = err.Error()
561+
httpapi.Write(rw, http.StatusInternalServerError, resp)
562562
return
563563
}
564564

@@ -895,15 +895,15 @@ func validWorkspaceTTLMillis(millis *int64, max time.Duration) (sql.NullInt64, e
895895
dur := time.Duration(*millis) * time.Millisecond
896896
truncated := dur.Truncate(time.Minute)
897897
if truncated < time.Minute {
898-
return sql.NullInt64{}, xerrors.New("ttl must be at least one minute")
898+
return sql.NullInt64{}, xerrors.New("time until shutdown must be at least one minute")
899899
}
900900

901901
if truncated > 24*7*time.Hour {
902-
return sql.NullInt64{}, xerrors.New("ttl must be less than 7 days")
902+
return sql.NullInt64{}, xerrors.New("time until shutdown must be less than 7 days")
903903
}
904904

905905
if truncated > max {
906-
return sql.NullInt64{}, xerrors.Errorf("ttl must be below template maximum %s", max.String())
906+
return sql.NullInt64{}, xerrors.Errorf("time until shutdown must be below template maximum %s", max.String())
907907
}
908908

909909
return sql.NullInt64{

coderd/workspaces_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ func TestPostWorkspacesByOrganization(t *testing.T) {
207207
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
208208
require.Len(t, apiErr.Validations, 1)
209209
require.Equal(t, apiErr.Validations[0].Field, "ttl_ms")
210-
require.Equal(t, apiErr.Validations[0].Detail, "ttl must be at least one minute")
210+
require.Equal(t, "time until shutdown must be at least one minute", apiErr.Validations[0].Detail)
211211
})
212212

213213
t.Run("AboveMax", func(t *testing.T) {
@@ -220,7 +220,7 @@ func TestPostWorkspacesByOrganization(t *testing.T) {
220220
req := codersdk.CreateWorkspaceRequest{
221221
TemplateID: template.ID,
222222
Name: "testing",
223-
TTLMillis: ptr.Ref((24*7*time.Hour + time.Minute).Milliseconds()),
223+
TTLMillis: ptr.Ref(template.MaxTTLMillis + time.Minute.Milliseconds()),
224224
}
225225
_, err := client.CreateWorkspace(context.Background(), template.OrganizationID, req)
226226
require.Error(t, err)
@@ -229,7 +229,7 @@ func TestPostWorkspacesByOrganization(t *testing.T) {
229229
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
230230
require.Len(t, apiErr.Validations, 1)
231231
require.Equal(t, apiErr.Validations[0].Field, "ttl_ms")
232-
require.Equal(t, apiErr.Validations[0].Detail, "ttl must be less than 7 days")
232+
require.Equal(t, "time until shutdown must be less than 7 days", apiErr.Validations[0].Detail)
233233
})
234234
})
235235

@@ -934,7 +934,7 @@ func TestWorkspaceUpdateTTL(t *testing.T) {
934934
{
935935
name: "below minimum ttl",
936936
ttlMillis: ptr.Ref((30 * time.Second).Milliseconds()),
937-
expectedError: "ttl must be at least one minute",
937+
expectedError: "time until shutdown must be at least one minute",
938938
},
939939
{
940940
name: "minimum ttl",
@@ -949,12 +949,12 @@ func TestWorkspaceUpdateTTL(t *testing.T) {
949949
{
950950
name: "above maximum ttl",
951951
ttlMillis: ptr.Ref((24*7*time.Hour + time.Minute).Milliseconds()),
952-
expectedError: "ttl must be less than 7 days",
952+
expectedError: "time until shutdown must be less than 7 days",
953953
},
954954
{
955955
name: "above template maximum ttl",
956956
ttlMillis: ptr.Ref((12 * time.Hour).Milliseconds()),
957-
expectedError: "ttl_ms: ttl must be below template maximum 8h0m0s",
957+
expectedError: "ttl_ms: time until shutdown must be below template maximum 8h0m0s",
958958
modifyTemplate: func(ctr *codersdk.CreateTemplateRequest) { ctr.MaxTTLMillis = ptr.Ref((8 * time.Hour).Milliseconds()) },
959959
},
960960
}

codersdk/error.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package codersdk
22

33
import (
4+
"fmt"
45
"net"
56

67
"golang.org/x/xerrors"
@@ -32,6 +33,12 @@ type ValidationError struct {
3233
Detail string `json:"detail" validate:"required"`
3334
}
3435

36+
func (e ValidationError) Error() string {
37+
return fmt.Sprintf("field: %s detail: %s", e.Field, e.Detail)
38+
}
39+
40+
var _ error = (*ValidationError)(nil)
41+
3542
// IsConnectionErr is a convenience function for checking if the source of an
3643
// error is due to a 'connection refused', 'no such host', etc.
3744
func IsConnectionErr(err error) bool {

codersdk/workspaces.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ func (c *Client) UpdateWorkspaceTTL(ctx context.Context, id uuid.UUID, req Updat
192192
path := fmt.Sprintf("/api/v2/workspaces/%s/ttl", id.String())
193193
res, err := c.Request(ctx, http.MethodPut, path, req)
194194
if err != nil {
195-
return xerrors.Errorf("update workspace ttl: %w", err)
195+
return xerrors.Errorf("update workspace time until shutdown: %w", err)
196196
}
197197
defer res.Body.Close()
198198
if res.StatusCode != http.StatusOK {
@@ -212,7 +212,7 @@ func (c *Client) PutExtendWorkspace(ctx context.Context, id uuid.UUID, req PutEx
212212
path := fmt.Sprintf("/api/v2/workspaces/%s/extend", id.String())
213213
res, err := c.Request(ctx, http.MethodPut, path, req)
214214
if err != nil {
215-
return xerrors.Errorf("extend workspace ttl: %w", err)
215+
return xerrors.Errorf("extend workspace time until shutdown: %w", err)
216216
}
217217
defer res.Body.Close()
218218
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusNotModified {

pty/pty.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,16 @@ type PTY interface {
2929
Resize(height uint16, width uint16) error
3030
}
3131

32+
// Process represents a process running in a PTY
33+
type Process interface {
34+
35+
// Wait for the command to complete. Returned error is as for exec.Cmd.Wait()
36+
Wait() error
37+
38+
// Kill the command process. Returned error is as for os.Process.Kill()
39+
Kill() error
40+
}
41+
3242
// WithFlags represents a PTY whose flags can be inspected, in particular
3343
// to determine whether local echo is enabled.
3444
type WithFlags interface {

pty/pty_other.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ package pty
55

66
import (
77
"os"
8+
"os/exec"
9+
"runtime"
810
"sync"
911

1012
"github.com/creack/pty"
@@ -27,6 +29,15 @@ type otherPty struct {
2729
pty, tty *os.File
2830
}
2931

32+
type otherProcess struct {
33+
pty *os.File
34+
cmd *exec.Cmd
35+
36+
// cmdDone protects access to cmdErr: anything reading cmdErr should read from cmdDone first.
37+
cmdDone chan any
38+
cmdErr error
39+
}
40+
3041
func (p *otherPty) Input() ReadWriter {
3142
return ReadWriter{
3243
Reader: p.tty,
@@ -66,3 +77,21 @@ func (p *otherPty) Close() error {
6677
}
6778
return nil
6879
}
80+
81+
func (p *otherProcess) Wait() error {
82+
<-p.cmdDone
83+
return p.cmdErr
84+
}
85+
86+
func (p *otherProcess) Kill() error {
87+
return p.cmd.Process.Kill()
88+
}
89+
90+
func (p *otherProcess) waitInternal() {
91+
// The GC can garbage collect the TTY FD before the command
92+
// has finished running. See:
93+
// https://github.com/creack/pty/issues/127#issuecomment-932764012
94+
p.cmdErr = p.cmd.Wait()
95+
runtime.KeepAlive(p.pty)
96+
close(p.cmdDone)
97+
}

0 commit comments

Comments
 (0)