Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 48e1cb4

Browse files
committed
fix(agent/usershell): check shell on darwin via dscl
1 parent e088303 commit 48e1cb4

File tree

4 files changed

+71
-31
lines changed

4 files changed

+71
-31
lines changed

agent/usershell/usershell_darwin.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
11
package usershell
22

3-
import "os"
3+
import (
4+
"os"
5+
"os/exec"
6+
"path/filepath"
7+
"strings"
8+
9+
"golang.org/x/xerrors"
10+
)
411

512
// Get returns the $SHELL environment variable.
6-
func Get(_ string) (string, error) {
7-
return os.Getenv("SHELL"), nil
13+
func Get(username string) (string, error) {
14+
// This command will output "UserShell: /bin/zsh" if successful, we
15+
// can ignore the error since we have fallback behavior.
16+
out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output()
17+
s, ok := strings.CutPrefix(string(out), "UserShell: ")
18+
if ok {
19+
return strings.TrimSpace(s), nil
20+
}
21+
if s = os.Getenv("SHELL"); s != "" {
22+
return s, nil
23+
}
24+
return "", xerrors.Errorf("shell for user %q not found via dscl or in $SHELL", username)
825
}

agent/usershell/usershell_other.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,8 @@ func Get(username string) (string, error) {
2727
}
2828
return parts[6], nil
2929
}
30-
return "", xerrors.Errorf("user %q not found in /etc/passwd", username)
30+
if s := os.Getenv("SHELL"); s != "" {
31+
return s, nil
32+
}
33+
return "", xerrors.Errorf("shell for user %q not found in /etc/passwd or $SHELL", username)
3134
}

agent/usershell/usershell_other_test.go

Lines changed: 0 additions & 27 deletions
This file was deleted.

agent/usershell/usershell_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package usershell_test
2+
3+
import (
4+
"os/user"
5+
"runtime"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/coder/coder/agent/usershell"
11+
)
12+
13+
func TestGet(t *testing.T) {
14+
t.Parallel()
15+
t.Run("Has", func(t *testing.T) {
16+
t.Parallel()
17+
if runtime.GOOS == "windows" {
18+
t.SkipNow()
19+
}
20+
shell, err := usershell.Get("root")
21+
require.NoError(t, err)
22+
require.NotEmpty(t, shell)
23+
})
24+
}
25+
26+
//nolint:paralleltest,tparallel // This test sets an environment variable.
27+
func TestGet_NoFallback(t *testing.T) {
28+
if runtime.GOOS == "windows" {
29+
t.SkipNow()
30+
}
31+
32+
// Disable env fallback for these tests.
33+
t.Setenv("SHELL", "")
34+
35+
t.Run("NotFound", func(t *testing.T) {
36+
_, err := usershell.Get("notauser")
37+
require.Error(t, err)
38+
})
39+
40+
t.Run("User", func(t *testing.T) {
41+
u, err := user.Current()
42+
require.NoError(t, err)
43+
shell, err := usershell.Get(u.Username)
44+
require.NoError(t, err)
45+
require.NotEmpty(t, shell)
46+
})
47+
}

0 commit comments

Comments
 (0)