diff --git a/agent/agent.go b/agent/agent.go new file mode 100644 index 0000000000000..285efe3dc9836 --- /dev/null +++ b/agent/agent.go @@ -0,0 +1,329 @@ +package agent + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "fmt" + "io" + "net" + "os/exec" + "os/user" + "sync" + "time" + + "cdr.dev/slog" + "github.com/coder/coder/agent/usershell" + "github.com/coder/coder/peer" + "github.com/coder/coder/peerbroker" + "github.com/coder/coder/pty" + "github.com/coder/retry" + + "github.com/gliderlabs/ssh" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" +) + +func DialSSH(conn *peer.Conn) (net.Conn, error) { + channel, err := conn.Dial(context.Background(), "ssh", &peer.ChannelOptions{ + Protocol: "ssh", + }) + if err != nil { + return nil, err + } + return channel.NetConn(), nil +} + +func DialSSHClient(conn *peer.Conn) (*gossh.Client, error) { + netConn, err := DialSSH(conn) + if err != nil { + return nil, err + } + sshConn, channels, requests, err := gossh.NewClientConn(netConn, "localhost:22", &gossh.ClientConfig{ + Config: gossh.Config{ + Ciphers: []string{"arcfour"}, + }, + // SSH host validation isn't helpful, because obtaining a peer + // connection already signifies user-intent to dial a workspace. + // #nosec + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + return nil, err + } + return gossh.NewClient(sshConn, channels, requests), nil +} + +type Options struct { + Logger slog.Logger +} + +type Dialer func(ctx context.Context) (*peerbroker.Listener, error) + +func New(dialer Dialer, options *Options) io.Closer { + ctx, cancelFunc := context.WithCancel(context.Background()) + server := &server{ + clientDialer: dialer, + options: options, + closeCancel: cancelFunc, + closed: make(chan struct{}), + } + server.init(ctx) + return server +} + +type server struct { + clientDialer Dialer + options *Options + + closeCancel context.CancelFunc + closeMutex sync.Mutex + closed chan struct{} + + sshServer *ssh.Server +} + +func (s *server) init(ctx context.Context) { + // Clients' should ignore the host key when connecting. + // The agent needs to authenticate with coderd to SSH, + // so SSH authentication doesn't improve security. + randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + randomSigner, err := gossh.NewSignerFromKey(randomHostKey) + if err != nil { + panic(err) + } + sshLogger := s.options.Logger.Named("ssh-server") + forwardHandler := &ssh.ForwardedTCPHandler{} + s.sshServer = &ssh.Server{ + ChannelHandlers: ssh.DefaultChannelHandlers, + ConnectionFailedCallback: func(conn net.Conn, err error) { + sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) + }, + Handler: func(session ssh.Session) { + err := s.handleSSHSession(session) + if err != nil { + s.options.Logger.Debug(ctx, "ssh session failed", slog.Error(err)) + _ = session.Exit(1) + return + } + }, + HostSigners: []ssh.Signer{randomSigner}, + LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { + // Allow local port forwarding all! + sshLogger.Debug(ctx, "local port forward", + slog.F("destination-host", destinationHost), + slog.F("destination-port", destinationPort)) + return true + }, + PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { + return true + }, + ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { + // Allow reverse port forwarding all! + sshLogger.Debug(ctx, "local port forward", + slog.F("bind-host", bindHost), + slog.F("bind-port", bindPort)) + return true + }, + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": forwardHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, + }, + ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { + return &gossh.ServerConfig{ + Config: gossh.Config{ + // "arcfour" is the fastest SSH cipher. We prioritize throughput + // over encryption here, because the WebRTC connection is already + // encrypted. If possible, we'd disable encryption entirely here. + Ciphers: []string{"arcfour"}, + }, + NoClientAuth: true, + } + }, + } + + go s.run(ctx) +} + +func (*server) handleSSHSession(session ssh.Session) error { + var ( + command string + args = []string{} + err error + ) + + username := session.User() + if username == "" { + currentUser, err := user.Current() + if err != nil { + return xerrors.Errorf("get current user: %w", err) + } + username = currentUser.Username + } + + // gliderlabs/ssh returns a command slice of zero + // when a shell is requested. + if len(session.Command()) == 0 { + command, err = usershell.Get(username) + if err != nil { + return xerrors.Errorf("get user shell: %w", err) + } + } else { + command = session.Command()[0] + if len(session.Command()) > 1 { + args = session.Command()[1:] + } + } + + signals := make(chan ssh.Signal) + breaks := make(chan bool) + defer close(signals) + defer close(breaks) + go func() { + for { + select { + case <-session.Context().Done(): + return + // Ignore signals and breaks for now! + case <-signals: + case <-breaks: + } + } + }() + + cmd := exec.CommandContext(session.Context(), command, args...) + cmd.Env = session.Environ() + + sshPty, windowSize, isPty := session.Pty() + if isPty { + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) + ptty, process, err := pty.Start(cmd) + if err != nil { + return xerrors.Errorf("start command: %w", err) + } + go func() { + for win := range windowSize { + err := ptty.Resize(uint16(win.Width), uint16(win.Height)) + if err != nil { + panic(err) + } + } + }() + go func() { + _, _ = io.Copy(ptty.Input(), session) + }() + go func() { + _, _ = io.Copy(session, ptty.Output()) + }() + _, _ = process.Wait() + _ = ptty.Close() + return nil + } + + cmd.Stdout = session + cmd.Stderr = session + // This blocks forever until stdin is received if we don't + // use StdinPipe. It's unknown what causes this. + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return xerrors.Errorf("create stdin pipe: %w", err) + } + go func() { + _, _ = io.Copy(stdinPipe, session) + }() + err = cmd.Start() + if err != nil { + return xerrors.Errorf("start: %w", err) + } + _ = cmd.Wait() + return nil +} + +func (s *server) run(ctx context.Context) { + var peerListener *peerbroker.Listener + var err error + // An exponential back-off occurs when the connection is failing to dial. + // This is to prevent server spam in case of a coderd outage. + for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { + peerListener, err = s.clientDialer(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + return + } + if s.isClosed() { + return + } + s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) + continue + } + s.options.Logger.Debug(context.Background(), "connected") + break + } + select { + case <-ctx.Done(): + return + default: + } + + for { + conn, err := peerListener.Accept() + if err != nil { + if s.isClosed() { + return + } + s.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err)) + s.run(ctx) + return + } + go s.handlePeerConn(ctx, conn) + } +} + +func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) { + for { + channel, err := conn.Accept(ctx) + if err != nil { + if errors.Is(err, peer.ErrClosed) || s.isClosed() { + return + } + s.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err)) + return + } + + switch channel.Protocol() { + case "ssh": + s.sshServer.HandleConn(channel.NetConn()) + default: + s.options.Logger.Warn(ctx, "unhandled protocol from channel", + slog.F("protocol", channel.Protocol()), + slog.F("label", channel.Label()), + ) + } + } +} + +// isClosed returns whether the API is closed or not. +func (s *server) isClosed() bool { + select { + case <-s.closed: + return true + default: + return false + } +} + +func (s *server) Close() error { + s.closeMutex.Lock() + defer s.closeMutex.Unlock() + if s.isClosed() { + return nil + } + close(s.closed) + s.closeCancel() + _ = s.sshServer.Close() + return nil +} diff --git a/agent/agent_test.go b/agent/agent_test.go new file mode 100644 index 0000000000000..662c054eae146 --- /dev/null +++ b/agent/agent_test.go @@ -0,0 +1,110 @@ +package agent_test + +import ( + "context" + "runtime" + "strings" + "testing" + + "github.com/pion/webrtc/v3" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "golang.org/x/crypto/ssh" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/agent" + "github.com/coder/coder/peer" + "github.com/coder/coder/peerbroker" + "github.com/coder/coder/peerbroker/proto" + "github.com/coder/coder/provisionersdk" + "github.com/coder/coder/pty/ptytest" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestAgent(t *testing.T) { + t.Parallel() + t.Run("SessionExec", func(t *testing.T) { + t.Parallel() + api := setup(t) + stream, err := api.NegotiateConnection(context.Background()) + require.NoError(t, err) + conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil), + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + sshClient, err := agent.DialSSHClient(conn) + require.NoError(t, err) + session, err := sshClient.NewSession() + require.NoError(t, err) + command := "echo test" + if runtime.GOOS == "windows" { + command = "cmd.exe /c echo test" + } + output, err := session.Output(command) + require.NoError(t, err) + require.Equal(t, "test", strings.TrimSpace(string(output))) + }) + + t.Run("SessionTTY", func(t *testing.T) { + t.Parallel() + api := setup(t) + stream, err := api.NegotiateConnection(context.Background()) + require.NoError(t, err) + conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil), + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + sshClient, err := agent.DialSSHClient(conn) + require.NoError(t, err) + session, err := sshClient.NewSession() + require.NoError(t, err) + prompt := "$" + command := "bash" + if runtime.GOOS == "windows" { + command = "cmd.exe" + prompt = ">" + } + err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{}) + require.NoError(t, err) + ptty := ptytest.New(t) + require.NoError(t, err) + session.Stdout = ptty.Output() + session.Stderr = ptty.Output() + session.Stdin = ptty.Input() + err = session.Start(command) + require.NoError(t, err) + ptty.ExpectMatch(prompt) + ptty.WriteLine("echo test") + ptty.ExpectMatch("test") + ptty.WriteLine("exit") + err = session.Wait() + require.NoError(t, err) + }) +} + +func setup(t *testing.T) proto.DRPCPeerBrokerClient { + client, server := provisionersdk.TransportPipe() + closer := agent.New(func(ctx context.Context) (*peerbroker.Listener, error) { + return peerbroker.Listen(server, &peer.ConnOptions{ + Logger: slogtest.Make(t, nil), + }) + }, &agent.Options{ + Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + }) + t.Cleanup(func() { + _ = client.Close() + _ = server.Close() + _ = closer.Close() + }) + return proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client)) +} diff --git a/agent/usershell/usershell_darwin.go b/agent/usershell/usershell_darwin.go new file mode 100644 index 0000000000000..d2b9a454e0470 --- /dev/null +++ b/agent/usershell/usershell_darwin.go @@ -0,0 +1,10 @@ +package usershell + +import "os" + +// Get returns the $SHELL environment variable. +// TODO: This should use "dscl" to fetch the proper value. See: +// https://stackoverflow.com/questions/16375519/how-to-get-the-default-shell +func Get(username string) (string, error) { + return os.Getenv("SHELL"), nil +} diff --git a/agent/usershell/usershell_other.go b/agent/usershell/usershell_other.go new file mode 100644 index 0000000000000..6f69a1e270ac3 --- /dev/null +++ b/agent/usershell/usershell_other.go @@ -0,0 +1,31 @@ +//go:build !windows && !darwin +// +build !windows,!darwin + +package usershell + +import ( + "os" + "strings" + + "golang.org/x/xerrors" +) + +// Get returns the /etc/passwd entry for the username provided. +func Get(username string) (string, error) { + contents, err := os.ReadFile("/etc/passwd") + if err != nil { + return "", xerrors.Errorf("read /etc/passwd: %w", err) + } + lines := strings.Split(string(contents), "\n") + for _, line := range lines { + if !strings.HasPrefix(line, username+":") { + continue + } + parts := strings.Split(line, ":") + if len(parts) < 7 { + return "", xerrors.Errorf("malformed user entry: %q", line) + } + return parts[6], nil + } + return "", xerrors.New("user not found in /etc/passwd and $SHELL not set") +} diff --git a/agent/usershell/usershell_other_test.go b/agent/usershell/usershell_other_test.go new file mode 100644 index 0000000000000..9469f31c70e70 --- /dev/null +++ b/agent/usershell/usershell_other_test.go @@ -0,0 +1,27 @@ +//go:build !windows && !darwin +// +build !windows,!darwin + +package usershell_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/agent/usershell" +) + +func TestGet(t *testing.T) { + t.Parallel() + t.Run("Has", func(t *testing.T) { + t.Parallel() + shell, err := usershell.Get("root") + require.NoError(t, err) + require.NotEmpty(t, shell) + }) + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + _, err := usershell.Get("notauser") + require.Error(t, err) + }) +} diff --git a/agent/usershell/usershell_windows.go b/agent/usershell/usershell_windows.go new file mode 100644 index 0000000000000..91bff1d8297cd --- /dev/null +++ b/agent/usershell/usershell_windows.go @@ -0,0 +1,6 @@ +package usershell + +// Get returns the command prompt binary name. +func Get(username string) (string, error) { + return "cmd.exe", nil +} diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 889f6241a442a..40eba2fb53942 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -164,7 +164,7 @@ func AwaitProjectImportJob(t *testing.T, client *codersdk.Client, organization s provisionerJob, err = client.ProjectImportJob(context.Background(), organization, job) require.NoError(t, err) return provisionerJob.Status.Completed() - }, 3*time.Second, 25*time.Millisecond) + }, 5*time.Second, 25*time.Millisecond) return provisionerJob } @@ -176,7 +176,7 @@ func AwaitWorkspaceProvisionJob(t *testing.T, client *codersdk.Client, organizat provisionerJob, err = client.WorkspaceProvisionJob(context.Background(), organization, job) require.NoError(t, err) return provisionerJob.Status.Completed() - }, 3*time.Second, 25*time.Millisecond) + }, 5*time.Second, 25*time.Millisecond) return provisionerJob } diff --git a/go.mod b/go.mod index 61fb3b0e83cee..c018ade5929d3 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/coder/retry v1.3.0 github.com/creack/pty v1.1.17 github.com/fatih/color v1.13.0 + github.com/gliderlabs/ssh v0.3.3 github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/render v1.0.1 github.com/go-playground/validator/v10 v10.10.0 @@ -64,6 +65,7 @@ require ( github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/alecthomas/chroma v0.10.0 // indirect + github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/cenkalti/backoff/v4 v4.1.2 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect diff --git a/go.sum b/go.sum index 5db435337eecd..cb9a8a7679c7a 100644 --- a/go.sum +++ b/go.sum @@ -132,6 +132,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF github.com/alexflint/go-filemutex v0.0.0-20171022225611-72bdc8eae2ae/go.mod h1:CgnQgUtFrFz9mxFNtED3jI5tLDjKlOM+oUF/sTk6ps0= github.com/andybalholm/crlf v0.0.0-20171020200849-670099aa064f/go.mod h1:k8feO4+kXDxro6ErPXBRTJ/ro2mf0SsFG8s7doP9kJE= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/arrow v0.0.0-20210818145353-234c94e4ce64/go.mod h1:2qMFB56yOP3KzkB3PbYZ4AlUFg3a88F67TIx5lB/WwY= github.com/apache/arrow/go/arrow v0.0.0-20211013220434-5962184e7a30/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= @@ -442,6 +444,8 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/gliderlabs/ssh v0.3.3 h1:mBQ8NiOgDkINJrZtoizkC3nDNYgSaWtxyem6S2XHBtA= +github.com/gliderlabs/ssh v0.3.3/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= diff --git a/pty/pty_other.go b/pty/pty_other.go index e2520a2387116..c3933878456cb 100644 --- a/pty/pty_other.go +++ b/pty/pty_other.go @@ -45,7 +45,7 @@ func (p *otherPty) Output() io.ReadWriter { func (p *otherPty) Resize(cols uint16, rows uint16) error { p.mutex.Lock() defer p.mutex.Unlock() - return pty.Setsize(p.tty, &pty.Winsize{ + return pty.Setsize(p.pty, &pty.Winsize{ Rows: rows, Cols: cols, }) diff --git a/pty/pty_windows.go b/pty/pty_windows.go index b6a9f8ae2e5dd..fa6f1932a48c3 100644 --- a/pty/pty_windows.go +++ b/pty/pty_windows.go @@ -96,12 +96,15 @@ func (p *ptyWindows) Close() error { return nil } p.closed = true + _ = p.outputWrite.Close() + _ = p.outputRead.Close() + _ = p.inputWrite.Close() + _ = p.inputRead.Close() ret, _, err := procClosePseudoConsole.Call(uintptr(p.console)) - if ret != 0 { + if ret < 0 { return xerrors.Errorf("close pseudo console: %w", err) } - _ = p.outputRead.Close() - _ = p.inputWrite.Close() + return nil } diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go index 7ea5b7a119f0d..60cd88ce606a2 100644 --- a/pty/ptytest/ptytest.go +++ b/pty/ptytest/ptytest.go @@ -5,8 +5,10 @@ import ( "bytes" "fmt" "io" + "os" "os/exec" "regexp" + "runtime" "strings" "testing" "unicode/utf8" @@ -28,10 +30,10 @@ func New(t *testing.T) *PTY { return create(t, ptty) } -func Start(t *testing.T, cmd *exec.Cmd) *PTY { - ptty, err := pty.Start(cmd) +func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) { + ptty, ps, err := pty.Start(cmd) require.NoError(t, err) - return create(t, ptty) + return create(t, ptty), ps } func create(t *testing.T, ptty pty.PTY) *PTY { @@ -86,10 +88,15 @@ func (p *PTY) ExpectMatch(str string) string { break } } + p.t.Logf("matched %q = %q", str, stripAnsi.ReplaceAllString(buffer.String(), "")) return buffer.String() } func (p *PTY) WriteLine(str string) { - _, err := fmt.Fprintf(p.PTY.Input(), "%s\n", str) + newline := "\n" + if runtime.GOOS == "windows" { + newline = "\r\n" + } + _, err := fmt.Fprintf(p.PTY.Input(), "%s%s", str, newline) require.NoError(p.t, err) } diff --git a/pty/start.go b/pty/start.go index 2b75843ee16c2..d0cbcd667d7b7 100644 --- a/pty/start.go +++ b/pty/start.go @@ -1,7 +1,10 @@ package pty -import "os/exec" +import ( + "os" + "os/exec" +) -func Start(cmd *exec.Cmd) (PTY, error) { +func Start(cmd *exec.Cmd) (PTY, *os.Process, error) { return startPty(cmd) } diff --git a/pty/start_other.go b/pty/start_other.go index 2f1a74633130e..6709cb271b1e4 100644 --- a/pty/start_other.go +++ b/pty/start_other.go @@ -4,6 +4,7 @@ package pty import ( + "os" "os/exec" "syscall" @@ -11,10 +12,10 @@ import ( "golang.org/x/xerrors" ) -func startPty(cmd *exec.Cmd) (PTY, error) { +func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) { ptty, tty, err := pty.Open() if err != nil { - return nil, xerrors.Errorf("open: %w", err) + return nil, nil, xerrors.Errorf("open: %w", err) } defer func() { _ = tty.Close() @@ -29,10 +30,11 @@ func startPty(cmd *exec.Cmd) (PTY, error) { err = cmd.Start() if err != nil { _ = ptty.Close() - return nil, xerrors.Errorf("start: %w", err) + return nil, nil, xerrors.Errorf("start: %w", err) } - return &otherPty{ + oPty := &otherPty{ pty: ptty, tty: tty, - }, nil + } + return oPty, cmd.Process, nil } diff --git a/pty/start_other_test.go b/pty/start_other_test.go index a5e7d94b36af1..30c87935bcd69 100644 --- a/pty/start_other_test.go +++ b/pty/start_other_test.go @@ -7,8 +7,9 @@ import ( "os/exec" "testing" - "github.com/coder/coder/pty/ptytest" "go.uber.org/goleak" + + "github.com/coder/coder/pty/ptytest" ) func TestMain(m *testing.M) { @@ -19,7 +20,7 @@ func TestStart(t *testing.T) { t.Parallel() t.Run("Echo", func(t *testing.T) { t.Parallel() - pty := ptytest.Start(t, exec.Command("echo", "test")) + pty, _ := ptytest.Start(t, exec.Command("echo", "test")) pty.ExpectMatch("test") }) } diff --git a/pty/start_windows.go b/pty/start_windows.go index 136ba245736ab..1019a969aef2c 100644 --- a/pty/start_windows.go +++ b/pty/start_windows.go @@ -11,47 +11,48 @@ import ( "unsafe" "golang.org/x/sys/windows" + "golang.org/x/xerrors" ) // Allocates a PTY and starts the specified command attached to it. // See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process -func startPty(cmd *exec.Cmd) (PTY, error) { +func startPty(cmd *exec.Cmd) (PTY, *os.Process, error) { fullPath, err := exec.LookPath(cmd.Path) if err != nil { - return nil, err + return nil, nil, err } pathPtr, err := windows.UTF16PtrFromString(fullPath) if err != nil { - return nil, err + return nil, nil, err } argsPtr, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(cmd.Args)) if err != nil { - return nil, err + return nil, nil, err } if cmd.Dir == "" { cmd.Dir, err = os.Getwd() if err != nil { - return nil, err + return nil, nil, err } } dirPtr, err := windows.UTF16PtrFromString(cmd.Dir) if err != nil { - return nil, err + return nil, nil, err } pty, err := newPty() if err != nil { - return nil, err + return nil, nil, err } winPty := pty.(*ptyWindows) attrs, err := windows.NewProcThreadAttributeList(1) if err != nil { - return nil, err + return nil, nil, err } // Taken from: https://github.com/microsoft/hcsshim/blob/2314362e977aa03b3ed245a4beb12d00422af0e2/internal/winapi/process.go#L6 err = attrs.Update(0x20016, unsafe.Pointer(winPty.console), unsafe.Sizeof(winPty.console)) if err != nil { - return nil, err + return nil, nil, err } startupInfo := &windows.StartupInfoEx{} @@ -73,12 +74,16 @@ func startPty(cmd *exec.Cmd) (PTY, error) { &processInfo, ) if err != nil { - return nil, err + return nil, nil, err } defer windows.CloseHandle(processInfo.Thread) defer windows.CloseHandle(processInfo.Process) - return pty, nil + process, err := os.FindProcess(int(processInfo.ProcessId)) + if err != nil { + return nil, nil, xerrors.Errorf("find process %d: %w", processInfo.ProcessId, err) + } + return pty, process, nil } // Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476 diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go index faee269776830..d0398d0dec019 100644 --- a/pty/start_windows_test.go +++ b/pty/start_windows_test.go @@ -20,12 +20,12 @@ func TestStart(t *testing.T) { t.Parallel() t.Run("Echo", func(t *testing.T) { t.Parallel() - pty := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) + pty, _ := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) pty.ExpectMatch("test") }) t.Run("Resize", func(t *testing.T) { t.Parallel() - pty := ptytest.Start(t, exec.Command("cmd.exe")) + pty, _ := ptytest.Start(t, exec.Command("cmd.exe")) err := pty.Resize(100, 50) require.NoError(t, err) }) diff --git a/templates/null/main.tf b/templates/null/main.tf deleted file mode 100644 index 9bb3f2042e2a4..0000000000000 --- a/templates/null/main.tf +++ /dev/null @@ -1,5 +0,0 @@ -variable "bananas" { - description = "hello!" -} - -resource "null_resource" "example" {}