From 39a9baba167af62cfc47a871e757636d04790bf3 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 7 Jul 2023 12:26:18 +0000 Subject: [PATCH] fix(agent/usershell): check shell on darwin via dscl --- agent/usershell/usershell_darwin.go | 23 +++++++++++-- agent/usershell/usershell_other.go | 5 ++- agent/usershell/usershell_other_test.go | 27 --------------- agent/usershell/usershell_test.go | 46 +++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 31 deletions(-) delete mode 100644 agent/usershell/usershell_other_test.go create mode 100644 agent/usershell/usershell_test.go diff --git a/agent/usershell/usershell_darwin.go b/agent/usershell/usershell_darwin.go index 532474f628b1e..47c4a4d21f869 100644 --- a/agent/usershell/usershell_darwin.go +++ b/agent/usershell/usershell_darwin.go @@ -1,8 +1,25 @@ package usershell -import "os" +import ( + "os" + "os/exec" + "path/filepath" + "strings" + + "golang.org/x/xerrors" +) // Get returns the $SHELL environment variable. -func Get(_ string) (string, error) { - return os.Getenv("SHELL"), nil +func Get(username string) (string, error) { + // This command will output "UserShell: /bin/zsh" if successful, we + // can ignore the error since we have fallback behavior. + out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output() + s, ok := strings.CutPrefix(string(out), "UserShell: ") + if ok { + return strings.TrimSpace(s), nil + } + if s = os.Getenv("SHELL"); s != "" { + return s, nil + } + return "", xerrors.Errorf("shell for user %q not found via dscl or in $SHELL", username) } diff --git a/agent/usershell/usershell_other.go b/agent/usershell/usershell_other.go index 230555de58d8c..d015b7ebf4111 100644 --- a/agent/usershell/usershell_other.go +++ b/agent/usershell/usershell_other.go @@ -27,5 +27,8 @@ func Get(username string) (string, error) { } return parts[6], nil } - return "", xerrors.Errorf("user %q not found in /etc/passwd", username) + if s := os.Getenv("SHELL"); s != "" { + return s, nil + } + return "", xerrors.Errorf("shell for user %q not found in /etc/passwd or $SHELL", username) } diff --git a/agent/usershell/usershell_other_test.go b/agent/usershell/usershell_other_test.go deleted file mode 100644 index 9469f31c70e70..0000000000000 --- a/agent/usershell/usershell_other_test.go +++ /dev/null @@ -1,27 +0,0 @@ -//go:build !windows && !darwin -// +build !windows,!darwin - -package usershell_test - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/coder/coder/agent/usershell" -) - -func TestGet(t *testing.T) { - t.Parallel() - t.Run("Has", func(t *testing.T) { - t.Parallel() - shell, err := usershell.Get("root") - require.NoError(t, err) - require.NotEmpty(t, shell) - }) - t.Run("NotFound", func(t *testing.T) { - t.Parallel() - _, err := usershell.Get("notauser") - require.Error(t, err) - }) -} diff --git a/agent/usershell/usershell_test.go b/agent/usershell/usershell_test.go new file mode 100644 index 0000000000000..676ee462ffe63 --- /dev/null +++ b/agent/usershell/usershell_test.go @@ -0,0 +1,46 @@ +package usershell_test + +import ( + "os/user" + "runtime" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/agent/usershell" +) + +//nolint:paralleltest,tparallel // This test sets an environment variable. +func TestGet(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + t.Run("Fallback", func(t *testing.T) { + t.Setenv("SHELL", "/bin/sh") + + t.Run("NonExistentUser", func(t *testing.T) { + shell, err := usershell.Get("notauser") + require.NoError(t, err) + require.Equal(t, "/bin/sh", shell) + }) + }) + + t.Run("NoFallback", func(t *testing.T) { + // Disable env fallback for these tests. + t.Setenv("SHELL", "") + + t.Run("NotFound", func(t *testing.T) { + _, err := usershell.Get("notauser") + require.Error(t, err) + }) + + t.Run("User", func(t *testing.T) { + u, err := user.Current() + require.NoError(t, err) + shell, err := usershell.Get(u.Username) + require.NoError(t, err) + require.NotEmpty(t, shell) + }) + }) +}