diff --git a/.vscode/settings.json b/.vscode/settings.json index d9b2b88f1798c..02c3b05cc42c5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -33,6 +33,7 @@ "drpcserver", "fatih", "goleak", + "gossh", "hashicorp", "httpmw", "isatty", @@ -51,9 +52,12 @@ "protobuf", "provisionerd", "provisionersdk", + "ptty", + "ptytest", "retrier", "sdkproto", "stretchr", + "tcpip", "tfexec", "tfstate", "unconvert", diff --git a/cli/clitest/clitest_test.go b/cli/clitest/clitest_test.go index fa11db7c04e9b..b1bd908bf6128 100644 --- a/cli/clitest/clitest_test.go +++ b/cli/clitest/clitest_test.go @@ -8,7 +8,7 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" + "github.com/coder/coder/pty/ptytest" ) func TestMain(m *testing.M) { @@ -21,11 +21,12 @@ func TestCli(t *testing.T) { client := coderdtest.New(t) cmd, config := clitest.New(t) clitest.SetupConfig(t, client, config) - cons := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) go func() { err := cmd.Execute() require.NoError(t, err) }() - _, err := cons.ExpectString("coder") - require.NoError(t, err) + pty.ExpectMatch("coder") } diff --git a/cli/login_test.go b/cli/login_test.go index b6c581cc41f12..24caf18e1aa3f 100644 --- a/cli/login_test.go +++ b/cli/login_test.go @@ -3,10 +3,11 @@ package cli_test import ( "testing" + "github.com/stretchr/testify/require" + "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" - "github.com/stretchr/testify/require" + "github.com/coder/coder/pty/ptytest" ) func TestLogin(t *testing.T) { @@ -26,7 +27,9 @@ func TestLogin(t *testing.T) { // accurately detect Windows ptys when they are not attached to a process: // https://github.com/mattn/go-isatty/issues/59 root, _ := clitest.New(t, "login", client.URL.String(), "--force-tty") - cons := console.New(t, root) + pty := ptytest.New(t) + root.SetIn(pty.Input()) + root.SetOut(pty.Output()) go func() { err := root.Execute() require.NoError(t, err) @@ -42,12 +45,9 @@ func TestLogin(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := cons.ExpectString(match) - require.NoError(t, err) - _, err = cons.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } - _, err := cons.ExpectString("Welcome to Coder") - require.NoError(t, err) + pty.ExpectMatch("Welcome to Coder") }) } diff --git a/cli/projectcreate_test.go b/cli/projectcreate_test.go index 6311aaf141f30..873a276263e5a 100644 --- a/cli/projectcreate_test.go +++ b/cli/projectcreate_test.go @@ -7,10 +7,10 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" "github.com/coder/coder/database" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/pty/ptytest" ) func TestProjectCreate(t *testing.T) { @@ -26,7 +26,9 @@ func TestProjectCreate(t *testing.T) { cmd, root := clitest.New(t, "projects", "create", "--directory", source, "--provisioner", string(database.ProvisionerTypeEcho)) clitest.SetupConfig(t, client, root) _ = coderdtest.NewProvisionerDaemon(t, client) - console := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) closeChan := make(chan struct{}) go func() { err := cmd.Execute() @@ -43,10 +45,8 @@ func TestProjectCreate(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := console.ExpectString(match) - require.NoError(t, err) - _, err = console.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } <-closeChan }) @@ -73,7 +73,9 @@ func TestProjectCreate(t *testing.T) { cmd, root := clitest.New(t, "projects", "create", "--directory", source, "--provisioner", string(database.ProvisionerTypeEcho)) clitest.SetupConfig(t, client, root) coderdtest.NewProvisionerDaemon(t, client) - cons := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) closeChan := make(chan struct{}) go func() { err := cmd.Execute() @@ -91,10 +93,8 @@ func TestProjectCreate(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := cons.ExpectString(match) - require.NoError(t, err) - _, err = cons.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } <-closeChan }) diff --git a/cli/root.go b/cli/root.go index f4e27a49d9e67..55e2b4c1d65ef 100644 --- a/cli/root.go +++ b/cli/root.go @@ -12,7 +12,6 @@ import ( "github.com/manifoldco/promptui" "github.com/mattn/go-isatty" "github.com/spf13/cobra" - "golang.org/x/xerrors" "github.com/coder/coder/cli/config" "github.com/coder/coder/coderd" @@ -138,14 +137,9 @@ func isTTY(cmd *cobra.Command) bool { } func prompt(cmd *cobra.Command, prompt *promptui.Prompt) (string, error) { - var ok bool - prompt.Stdin, ok = cmd.InOrStdin().(io.ReadCloser) - if !ok { - return "", xerrors.New("stdin must be a readcloser") - } - prompt.Stdout, ok = cmd.OutOrStdout().(io.WriteCloser) - if !ok { - return "", xerrors.New("stdout must be a readcloser") + prompt.Stdin = io.NopCloser(cmd.InOrStdin()) + prompt.Stdout = readWriteCloser{ + Writer: cmd.OutOrStdout(), } // The prompt library displays defaults in a jarring way for the user @@ -199,3 +193,10 @@ func prompt(cmd *cobra.Command, prompt *promptui.Prompt) (string, error) { return value, err } + +// readWriteCloser fakes reads, writes, and closing! +type readWriteCloser struct { + io.Reader + io.Writer + io.Closer +} diff --git a/cli/workspacecreate_test.go b/cli/workspacecreate_test.go index 306caa65c4b0c..b3b1ca26915f7 100644 --- a/cli/workspacecreate_test.go +++ b/cli/workspacecreate_test.go @@ -3,12 +3,13 @@ package cli_test import ( "testing" + "github.com/stretchr/testify/require" + "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/stretchr/testify/require" + "github.com/coder/coder/pty/ptytest" ) func TestWorkspaceCreate(t *testing.T) { @@ -36,7 +37,9 @@ func TestWorkspaceCreate(t *testing.T) { cmd, root := clitest.New(t, "workspaces", "create", project.Name) clitest.SetupConfig(t, client, root) - cons := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) closeChan := make(chan struct{}) go func() { err := cmd.Execute() @@ -51,13 +54,10 @@ func TestWorkspaceCreate(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := cons.ExpectString(match) - require.NoError(t, err) - _, err = cons.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } - _, err := cons.ExpectString("Create") - require.NoError(t, err) + pty.ExpectMatch("Create") <-closeChan }) } diff --git a/coderd/projectimport_test.go b/coderd/projectimport_test.go index 06140190f51d5..b9df691233576 100644 --- a/coderd/projectimport_test.go +++ b/coderd/projectimport_test.go @@ -5,13 +5,14 @@ import ( "net/http" "testing" + "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" "github.com/coder/coder/database" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/stretchr/testify/require" ) func TestPostProjectImportByOrganization(t *testing.T) { diff --git a/codersdk/projectimport_test.go b/codersdk/projectimport_test.go index 8cc6b28a23f6c..ccbe01345845a 100644 --- a/codersdk/projectimport_test.go +++ b/codersdk/projectimport_test.go @@ -5,12 +5,13 @@ import ( "testing" "time" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/google/uuid" - "github.com/stretchr/testify/require" ) func TestCreateProjectImportJob(t *testing.T) { diff --git a/console/conpty/conpty.go b/console/conpty/conpty.go deleted file mode 100644 index a57264b8ff195..0000000000000 --- a/console/conpty/conpty.go +++ /dev/null @@ -1,107 +0,0 @@ -//go:build windows -// +build windows - -// Original copyright 2020 ActiveState Software. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file - -package conpty - -import ( - "fmt" - "io" - "os" - - "golang.org/x/sys/windows" -) - -// ConPty represents a windows pseudo console. -type ConPty struct { - hpCon windows.Handle - outPipePseudoConsoleSide windows.Handle - outPipeOurSide windows.Handle - inPipeOurSide windows.Handle - inPipePseudoConsoleSide windows.Handle - consoleSize uintptr - outFilePseudoConsoleSide *os.File - outFileOurSide *os.File - inFilePseudoConsoleSide *os.File - inFileOurSide *os.File - closed bool -} - -// New returns a new ConPty pseudo terminal device -func New(columns int16, rows int16) (*ConPty, error) { - c := &ConPty{ - consoleSize: uintptr(columns) + (uintptr(rows) << 16), - } - - return c, c.createPseudoConsoleAndPipes() -} - -// Close closes the pseudo-terminal and cleans up all attached resources -func (c *ConPty) Close() error { - // Trying to close these pipes multiple times will result in an - // access violation - if c.closed { - return nil - } - - err := closePseudoConsole(c.hpCon) - c.outFilePseudoConsoleSide.Close() - c.outFileOurSide.Close() - c.inFilePseudoConsoleSide.Close() - c.inFileOurSide.Close() - c.closed = true - return err -} - -// OutPipe returns the output pipe of the pseudo terminal -func (c *ConPty) OutPipe() *os.File { - return c.outFilePseudoConsoleSide -} - -func (c *ConPty) Reader() io.Reader { - return c.outFileOurSide -} - -// InPipe returns input pipe of the pseudo terminal -// Note: It is safer to use the Write method to prevent partially-written VT sequences -// from corrupting the terminal -func (c *ConPty) InPipe() *os.File { - return c.inFilePseudoConsoleSide -} - -func (c *ConPty) WriteString(str string) (int, error) { - return c.inFileOurSide.WriteString(str) -} - -func (c *ConPty) createPseudoConsoleAndPipes() error { - // Create the stdin pipe - if err := windows.CreatePipe(&c.inPipePseudoConsoleSide, &c.inPipeOurSide, nil, 0); err != nil { - return err - } - - // Create the stdout pipe - if err := windows.CreatePipe(&c.outPipeOurSide, &c.outPipePseudoConsoleSide, nil, 0); err != nil { - return err - } - - // Create the pty with our stdin/stdout - if err := createPseudoConsole(c.consoleSize, c.inPipePseudoConsoleSide, c.outPipePseudoConsoleSide, &c.hpCon); err != nil { - return fmt.Errorf("failed to create pseudo console: %d, %v", uintptr(c.hpCon), err) - } - - c.outFilePseudoConsoleSide = os.NewFile(uintptr(c.outPipePseudoConsoleSide), "|0") - c.outFileOurSide = os.NewFile(uintptr(c.outPipeOurSide), "|1") - - c.inFilePseudoConsoleSide = os.NewFile(uintptr(c.inPipePseudoConsoleSide), "|2") - c.inFileOurSide = os.NewFile(uintptr(c.inPipeOurSide), "|3") - c.closed = false - - return nil -} - -func (c *ConPty) Resize(cols uint16, rows uint16) error { - return resizePseudoConsole(c.hpCon, uintptr(cols)+(uintptr(rows)<<16)) -} diff --git a/console/conpty/syscall.go b/console/conpty/syscall.go deleted file mode 100644 index 284603aa8fdc7..0000000000000 --- a/console/conpty/syscall.go +++ /dev/null @@ -1,53 +0,0 @@ -//go:build windows -// +build windows - -// Copyright 2020 ActiveState Software. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file - -package conpty - -import ( - "unsafe" - - "golang.org/x/sys/windows" -) - -var ( - kernel32 = windows.NewLazySystemDLL("kernel32.dll") - procResizePseudoConsole = kernel32.NewProc("ResizePseudoConsole") - procCreatePseudoConsole = kernel32.NewProc("CreatePseudoConsole") - procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") -) - -func createPseudoConsole(consoleSize uintptr, ptyIn windows.Handle, ptyOut windows.Handle, hpCon *windows.Handle) (err error) { - r1, _, e1 := procCreatePseudoConsole.Call( - consoleSize, - uintptr(ptyIn), - uintptr(ptyOut), - 0, - uintptr(unsafe.Pointer(hpCon)), - ) - - if r1 != 0 { // !S_OK - err = e1 - } - return -} - -func resizePseudoConsole(handle windows.Handle, consoleSize uintptr) (err error) { - r1, _, e1 := procResizePseudoConsole.Call(uintptr(handle), consoleSize) - if r1 != 0 { // !S_OK - err = e1 - } - return -} - -func closePseudoConsole(handle windows.Handle) (err error) { - r1, _, e1 := procClosePseudoConsole.Call(uintptr(handle)) - if r1 == 0 { - err = e1 - } - - return -} diff --git a/console/console.go b/console/console.go deleted file mode 100644 index e5af7fa20977b..0000000000000 --- a/console/console.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console - -import ( - "bufio" - "fmt" - "io" - "io/ioutil" - "log" - "os" - "unicode/utf8" - - "github.com/coder/coder/console/pty" -) - -// Console is an interface to automate input and output for interactive -// applications. Console can block until a specified output is received and send -// input back on it's tty. Console can also multiplex other sources of input -// and multiplex its output to other writers. -type Console struct { - opts Opts - pty pty.Pty - runeReader *bufio.Reader - closers []io.Closer -} - -// Opt allows setting Console options. -type Opt func(*Opts) error - -// Opts provides additional options on creating a Console. -type Opts struct { - Logger *log.Logger - Stdouts []io.Writer - ExpectObservers []Observer -} - -// Observer provides an interface for a function callback that will -// be called after each Expect operation. -// matchers will be the list of active matchers when an error occurred, -// or a list of matchers that matched `buf` when err is nil. -// buf is the captured output that was matched against. -// err is error that might have occurred. May be nil. -type Observer func(matchers []Matcher, buf string, err error) - -// WithStdout adds writers that Console duplicates writes to, similar to the -// Unix tee(1) command. -// -// Each write is written to each listed writer, one at a time. Console is the -// last writer, writing to it's internal buffer for matching expects. -// If a listed writer returns an error, that overall write operation stops and -// returns the error; it does not continue down the list. -func WithStdout(writers ...io.Writer) Opt { - return func(opts *Opts) error { - opts.Stdouts = append(opts.Stdouts, writers...) - return nil - } -} - -// WithLogger adds a logger for Console to log debugging information to. By -// default Console will discard logs. -func WithLogger(logger *log.Logger) Opt { - return func(opts *Opts) error { - opts.Logger = logger - return nil - } -} - -// WithExpectObserver adds an ExpectObserver to allow monitoring Expect operations. -func WithExpectObserver(observers ...Observer) Opt { - return func(opts *Opts) error { - opts.ExpectObservers = append(opts.ExpectObservers, observers...) - return nil - } -} - -// NewConsole returns a new Console with the given options. -func NewConsole(opts ...Opt) (*Console, error) { - options := Opts{ - Logger: log.New(ioutil.Discard, "", 0), - } - - for _, opt := range opts { - if err := opt(&options); err != nil { - return nil, err - } - } - - consolePty, err := pty.New() - if err != nil { - return nil, err - } - closers := []io.Closer{consolePty} - reader := consolePty.Reader() - - cons := &Console{ - opts: options, - pty: consolePty, - runeReader: bufio.NewReaderSize(reader, utf8.UTFMax), - closers: closers, - } - - return cons, nil -} - -// Tty returns an input Tty for accepting input -func (c *Console) InTty() *os.File { - return c.pty.InPipe() -} - -// OutTty returns an output tty for writing -func (c *Console) OutTty() *os.File { - return c.pty.OutPipe() -} - -// Close closes Console's tty. Calling Close will unblock Expect and ExpectEOF. -func (c *Console) Close() error { - for _, fd := range c.closers { - err := fd.Close() - if err != nil { - c.Logf("failed to close: %s", err) - } - } - return nil -} - -// Send writes string s to Console's tty. -func (c *Console) Send(s string) (int, error) { - c.Logf("console send: %q", s) - n, err := c.pty.WriteString(s) - return n, err -} - -// SendLine writes string s to Console's tty with a trailing newline. -func (c *Console) SendLine(s string) (int, error) { - bytes, err := c.Send(fmt.Sprintf("%s\n", s)) - - return bytes, err -} - -// Log prints to Console's logger. -// Arguments are handled in the manner of fmt.Print. -func (c *Console) Log(v ...interface{}) { - c.opts.Logger.Print(v...) -} - -// Logf prints to Console's logger. -// Arguments are handled in the manner of fmt.Printf. -func (c *Console) Logf(format string, v ...interface{}) { - c.opts.Logger.Printf(format, v...) -} diff --git a/console/doc.go b/console/doc.go deleted file mode 100644 index 7a5fc545cd982..0000000000000 --- a/console/doc.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package expect provides an expect-like interface to automate control of -// applications. It is unlike expect in that it does not spawn or manage -// process lifecycle. This package only focuses on expecting output and sending -// input through it's psuedoterminal. -package console diff --git a/console/expect.go b/console/expect.go deleted file mode 100644 index c2e3f583b0a06..0000000000000 --- a/console/expect.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console - -import ( - "bufio" - "bytes" - "fmt" - "io" - "unicode/utf8" -) - -// Expectf reads from the Console's tty until the provided formatted string -// is read or an error occurs, and returns the buffer read by Console. -func (c *Console) Expectf(format string, args ...interface{}) (string, error) { - return c.Expect(String(fmt.Sprintf(format, args...))) -} - -// ExpectString reads from Console's tty until the provided string is read or -// an error occurs, and returns the buffer read by Console. -func (c *Console) ExpectString(s string) (string, error) { - return c.Expect(String(s)) -} - -// Expect reads from Console's tty until a condition specified from opts is -// encountered or an error occurs, and returns the buffer read by console. -// No extra bytes are read once a condition is met, so if a program isn't -// expecting input yet, it will be blocked. Sends are queued up in tty's -// internal buffer so that the next Expect will read the remaining bytes (i.e. -// rest of prompt) as well as its conditions. -func (c *Console) Expect(opts ...ExpectOpt) (string, error) { - var options ExpectOpts - for _, opt := range opts { - if err := opt(&options); err != nil { - return "", err - } - } - - buf := new(bytes.Buffer) - writer := io.MultiWriter(append(c.opts.Stdouts, buf)...) - runeWriter := bufio.NewWriterSize(writer, utf8.UTFMax) - - var matcher Matcher - var err error - - defer func() { - for _, observer := range c.opts.ExpectObservers { - if matcher != nil { - observer([]Matcher{matcher}, buf.String(), err) - return - } - observer(options.Matchers, buf.String(), err) - } - }() - - for { - var r rune - r, _, err = c.runeReader.ReadRune() - if err != nil { - matcher = options.Match(err) - if matcher != nil { - err = nil - break - } - return buf.String(), err - } - - c.Logf("expect read: %q", string(r)) - _, err = runeWriter.WriteRune(r) - if err != nil { - return buf.String(), err - } - - // Immediately flush rune to the underlying writers. - err = runeWriter.Flush() - if err != nil { - return buf.String(), err - } - - matcher = options.Match(buf) - if matcher != nil { - break - } - } - - if matcher != nil { - cb, ok := matcher.(CallbackMatcher) - if ok { - err = cb.Callback(buf) - if err != nil { - return buf.String(), err - } - } - } - - return buf.String(), err -} diff --git a/console/expect_opt.go b/console/expect_opt.go deleted file mode 100644 index fec0d9b8f3e0b..0000000000000 --- a/console/expect_opt.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console - -import ( - "bytes" - "strings" - "time" -) - -// ExpectOpt allows settings Expect options. -type ExpectOpt func(*ExpectOpts) error - -// Callback is a callback function to execute if a match is found for -// the chained matcher. -type Callback func(buf *bytes.Buffer) error - -// ExpectOpts provides additional options on Expect. -type ExpectOpts struct { - Matchers []Matcher - ReadTimeout *time.Duration -} - -// Match sequentially calls Match on all matchers in ExpectOpts and returns the -// first matcher if a match exists, otherwise nil. -func (eo ExpectOpts) Match(v interface{}) Matcher { - for _, matcher := range eo.Matchers { - if matcher.Match(v) { - return matcher - } - } - return nil -} - -// CallbackMatcher is a matcher that provides a Callback function. -type CallbackMatcher interface { - // Callback executes the matcher's callback with the content buffer at the - // time of match. - Callback(buf *bytes.Buffer) error -} - -// Matcher provides an interface for finding a match in content read from -// Console's tty. -type Matcher interface { - // Match returns true iff a match is found. - Match(v interface{}) bool - Criteria() interface{} -} - -// stringMatcher fulfills the Matcher interface to match strings against a given -// bytes.Buffer. -type stringMatcher struct { - str string -} - -func (sm *stringMatcher) Match(v interface{}) bool { - buf, ok := v.(*bytes.Buffer) - if !ok { - return false - } - if strings.Contains(buf.String(), sm.str) { - return true - } - return false -} - -func (sm *stringMatcher) Criteria() interface{} { - return sm.str -} - -// allMatcher fulfills the Matcher interface to match a group of ExpectOpt -// against any value. -type allMatcher struct { - options ExpectOpts -} - -func (am *allMatcher) Match(v interface{}) bool { - var matchers []Matcher - for _, matcher := range am.options.Matchers { - if matcher.Match(v) { - continue - } - matchers = append(matchers, matcher) - } - - am.options.Matchers = matchers - return len(matchers) == 0 -} - -func (am *allMatcher) Criteria() interface{} { - var criteria []interface{} - for _, matcher := range am.options.Matchers { - criteria = append(criteria, matcher.Criteria()) - } - return criteria -} - -// All adds an Expect condition to exit if the content read from Console's tty -// matches all of the provided ExpectOpt, in any order. -func All(expectOpts ...ExpectOpt) ExpectOpt { - return func(opts *ExpectOpts) error { - var options ExpectOpts - for _, opt := range expectOpts { - if err := opt(&options); err != nil { - return err - } - } - - opts.Matchers = append(opts.Matchers, &allMatcher{ - options: options, - }) - return nil - } -} - -// String adds an Expect condition to exit if the content read from Console's -// tty contains any of the given strings. -func String(strs ...string) ExpectOpt { - return func(opts *ExpectOpts) error { - for _, str := range strs { - opts.Matchers = append(opts.Matchers, &stringMatcher{ - str: str, - }) - } - return nil - } -} diff --git a/console/expect_opt_test.go b/console/expect_opt_test.go deleted file mode 100644 index 91efc935fca4e..0000000000000 --- a/console/expect_opt_test.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console_test - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/require" - - . "github.com/coder/coder/console" -) - -func TestExpectOptString(t *testing.T) { - t.Parallel() - - tests := []struct { - title string - opt ExpectOpt - data string - expected bool - }{ - { - "No args", - String(), - "Hello world", - false, - }, - { - "Single arg", - String("Hello"), - "Hello world", - true, - }, - { - "Multiple arg", - String("other", "world"), - "Hello world", - true, - }, - { - "No matches", - String("hello"), - "Hello world", - false, - }, - } - - for _, test := range tests { - test := test - t.Run(test.title, func(t *testing.T) { - t.Parallel() - - var options ExpectOpts - err := test.opt(&options) - require.Nil(t, err) - - buf := new(bytes.Buffer) - _, err = buf.WriteString(test.data) - require.Nil(t, err) - - matcher := options.Match(buf) - if test.expected { - require.NotNil(t, matcher) - } else { - require.Nil(t, matcher) - } - }) - } -} - -func TestExpectOptAll(t *testing.T) { - t.Parallel() - - tests := []struct { - title string - opt ExpectOpt - data string - expected bool - }{ - { - "No opts", - All(), - "Hello world", - true, - }, - { - "Single string match", - All(String("Hello")), - "Hello world", - true, - }, - { - "Single string no match", - All(String("Hello")), - "No match", - false, - }, - { - "Ordered strings match", - All(String("Hello"), String("world")), - "Hello world", - true, - }, - { - "Ordered strings not all match", - All(String("Hello"), String("world")), - "Hello", - false, - }, - { - "Unordered strings", - All(String("world"), String("Hello")), - "Hello world", - true, - }, - { - "Unordered strings not all match", - All(String("world"), String("Hello")), - "Hello", - false, - }, - { - "Repeated strings match", - All(String("Hello"), String("Hello")), - "Hello world", - true, - }, - } - - for _, test := range tests { - test := test - t.Run(test.title, func(t *testing.T) { - t.Parallel() - var options ExpectOpts - err := test.opt(&options) - require.Nil(t, err) - - buf := new(bytes.Buffer) - _, err = buf.WriteString(test.data) - require.Nil(t, err) - - matcher := options.Match(buf) - if test.expected { - require.NotNil(t, matcher) - } else { - require.Nil(t, matcher) - } - }) - } -} diff --git a/console/expect_test.go b/console/expect_test.go deleted file mode 100644 index c80f981717d44..0000000000000 --- a/console/expect_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console_test - -import ( - "bufio" - "errors" - "fmt" - "io" - "runtime/debug" - "strings" - "sync" - "testing" - - "golang.org/x/xerrors" - - . "github.com/coder/coder/console" -) - -var ( - ErrWrongAnswer = xerrors.New("wrong answer") -) - -type Survey struct { - Prompt string - Answer string -} - -func Prompt(in io.Reader, out io.Writer) error { - reader := bufio.NewReader(in) - - for _, survey := range []Survey{ - { - "What is 1+1?", "2", - }, - { - "What is Netflix backwards?", "xilfteN", - }, - } { - _, err := fmt.Fprintf(out, "%s: ", survey.Prompt) - if err != nil { - return err - } - text, err := reader.ReadString('\n') - if err != nil { - return err - } - - _, err = fmt.Fprint(out, text) - if err != nil { - return err - } - text = strings.TrimSpace(text) - if text != survey.Answer { - return ErrWrongAnswer - } - } - - return nil -} - -func newTestConsole(t *testing.T, opts ...Opt) (*Console, error) { - opts = append([]Opt{ - expectNoError(t), - }, opts...) - return NewConsole(opts...) -} - -func expectNoError(t *testing.T) Opt { - return WithExpectObserver( - func(matchers []Matcher, buf string, err error) { - if err == nil { - return - } - if len(matchers) == 0 { - t.Fatalf("Error occurred while matching %q: %s\n%s", buf, err, string(debug.Stack())) - } else { - var criteria []string - for _, matcher := range matchers { - criteria = append(criteria, fmt.Sprintf("%q", matcher.Criteria())) - } - t.Fatalf("Failed to find [%s] in %q: %s\n%s", strings.Join(criteria, ", "), buf, err, string(debug.Stack())) - } - }, - ) -} - -func testCloser(t *testing.T, closer io.Closer) { - if err := closer.Close(); err != nil { - t.Errorf("Close failed: %s", err) - debug.PrintStack() - } -} - -func TestExpectf(t *testing.T) { - t.Parallel() - - console, err := newTestConsole(t) - if err != nil { - t.Errorf("Expected no error but got'%s'", err) - } - defer testCloser(t, console) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - console.Expectf("What is 1+%d?", 1) - console.SendLine("2") - console.Expectf("What is %s backwards?", "Netflix") - console.SendLine("xilfteN") - }() - - err = Prompt(console.InTty(), console.OutTty()) - if err != nil { - t.Errorf("Expected no error but got '%s'", err) - } - wg.Wait() -} - -func TestExpect(t *testing.T) { - t.Parallel() - - console, err := newTestConsole(t) - if err != nil { - t.Errorf("Expected no error but got'%s'", err) - } - defer testCloser(t, console) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - console.ExpectString("What is 1+1?") - console.SendLine("2") - console.ExpectString("What is Netflix backwards?") - console.SendLine("xilfteN") - }() - - err = Prompt(console.InTty(), console.OutTty()) - if err != nil { - t.Errorf("Expected no error but got '%s'", err) - } - wg.Wait() -} - -func TestExpectOutput(t *testing.T) { - t.Parallel() - - console, err := newTestConsole(t) - if err != nil { - t.Errorf("Expected no error but got'%s'", err) - } - defer testCloser(t, console) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - console.ExpectString("What is 1+1?") - console.SendLine("3") - }() - - err = Prompt(console.InTty(), console.OutTty()) - if err == nil || !errors.Is(err, ErrWrongAnswer) { - t.Errorf("Expected error '%s' but got '%s' instead", ErrWrongAnswer, err) - } - wg.Wait() -} diff --git a/console/pty/pty.go b/console/pty/pty.go deleted file mode 100644 index 86b56e68f922e..0000000000000 --- a/console/pty/pty.go +++ /dev/null @@ -1,21 +0,0 @@ -package pty - -import ( - "io" - "os" -) - -// Pty is the minimal pseudo-tty interface we require. -type Pty interface { - InPipe() *os.File - OutPipe() *os.File - Resize(cols uint16, rows uint16) error - WriteString(str string) (int, error) - Reader() io.Reader - Close() error -} - -// New creates a new Pty. -func New() (Pty, error) { - return newPty() -} diff --git a/console/pty/pty_windows.go b/console/pty/pty_windows.go deleted file mode 100644 index 01fbe39169f04..0000000000000 --- a/console/pty/pty_windows.go +++ /dev/null @@ -1,78 +0,0 @@ -//go:build windows -// +build windows - -package pty - -import ( - "io" - "os" - - "golang.org/x/sys/windows" - - "github.com/coder/coder/console/conpty" -) - -func newPty() (Pty, error) { - // We use the CreatePseudoConsole API which was introduced in build 17763 - vsn := windows.RtlGetVersion() - if vsn.MajorVersion < 10 || - vsn.BuildNumber < 17763 { - // If the CreatePseudoConsole API is not available, we fall back to a simpler - // implementation that doesn't create an actual PTY - just uses os.Pipe - return pipePty() - } - - return conpty.New(80, 80) -} - -func pipePty() (Pty, error) { - inFilePipeSide, inFileOurSide, err := os.Pipe() - if err != nil { - return nil, err - } - - outFileOurSide, outFilePipeSide, err := os.Pipe() - if err != nil { - return nil, err - } - - return &pipePtyVal{ - inFilePipeSide, - inFileOurSide, - outFileOurSide, - outFilePipeSide, - }, nil -} - -type pipePtyVal struct { - inFilePipeSide, inFileOurSide *os.File - outFileOurSide, outFilePipeSide *os.File -} - -func (p *pipePtyVal) InPipe() *os.File { - return p.inFilePipeSide -} - -func (p *pipePtyVal) OutPipe() *os.File { - return p.outFilePipeSide -} - -func (p *pipePtyVal) Reader() io.Reader { - return p.outFileOurSide -} - -func (p *pipePtyVal) WriteString(str string) (int, error) { - return p.inFileOurSide.WriteString(str) -} - -func (p *pipePtyVal) Resize(uint16, uint16) error { - return nil -} - -func (p *pipePtyVal) Close() error { - p.inFileOurSide.Close() - p.inFilePipeSide.Close() - p.outFilePipeSide.Close() - p.outFileOurSide.Close() - return nil -} diff --git a/console/test_console.go b/console/test_console.go deleted file mode 100644 index d1d845d6cb4db..0000000000000 --- a/console/test_console.go +++ /dev/null @@ -1,45 +0,0 @@ -package console - -import ( - "bufio" - "io" - "regexp" - "testing" - - "github.com/spf13/cobra" - "github.com/stretchr/testify/require" -) - -var ( - // Used to ensure terminal output doesn't have anything crazy! - // See: https://stackoverflow.com/a/29497680 - stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") -) - -// New creates a new TTY bound to the command provided. -// All ANSI escape codes are stripped to provide clean output. -func New(t *testing.T, cmd *cobra.Command) *Console { - reader, writer := io.Pipe() - scanner := bufio.NewScanner(reader) - t.Cleanup(func() { - _ = reader.Close() - _ = writer.Close() - }) - go func() { - for scanner.Scan() { - if scanner.Err() != nil { - return - } - t.Log(stripAnsi.ReplaceAllString(scanner.Text(), "")) - } - }() - - console, err := NewConsole(WithStdout(writer)) - require.NoError(t, err) - t.Cleanup(func() { - console.Close() - }) - cmd.SetIn(console.InTty()) - cmd.SetOut(console.OutTty()) - return console -} diff --git a/go.mod b/go.mod index d002eeadbfa77..e224d221c55bc 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/coder/retry v1.3.0 github.com/creack/pty v1.1.17 github.com/fatih/color v1.13.0 + github.com/gliderlabs/ssh v0.3.3 github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/render v1.0.1 github.com/go-playground/validator/v10 v10.10.0 @@ -64,6 +65,7 @@ require ( github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/alecthomas/chroma v0.10.0 // indirect + github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/cenkalti/backoff/v4 v4.1.2 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect diff --git a/go.sum b/go.sum index 370ff4aeb0b64..416fe1da0f69d 100644 --- a/go.sum +++ b/go.sum @@ -132,6 +132,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF github.com/alexflint/go-filemutex v0.0.0-20171022225611-72bdc8eae2ae/go.mod h1:CgnQgUtFrFz9mxFNtED3jI5tLDjKlOM+oUF/sTk6ps0= github.com/andybalholm/crlf v0.0.0-20171020200849-670099aa064f/go.mod h1:k8feO4+kXDxro6ErPXBRTJ/ro2mf0SsFG8s7doP9kJE= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/arrow v0.0.0-20210818145353-234c94e4ce64/go.mod h1:2qMFB56yOP3KzkB3PbYZ4AlUFg3a88F67TIx5lB/WwY= github.com/apache/arrow/go/arrow v0.0.0-20211013220434-5962184e7a30/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= @@ -441,6 +443,8 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/gliderlabs/ssh v0.3.3 h1:mBQ8NiOgDkINJrZtoizkC3nDNYgSaWtxyem6S2XHBtA= +github.com/gliderlabs/ssh v0.3.3/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= diff --git a/peer/channel.go b/peer/channel.go index d1f4930fe31f7..732a6a1c1de2d 100644 --- a/peer/channel.go +++ b/peer/channel.go @@ -263,6 +263,11 @@ func (c *Channel) Label() string { return c.dc.Label() } +// Protocol returns the protocol of the underlying DataChannel. +func (c *Channel) Protocol() string { + return c.dc.Protocol() +} + // NetConn wraps the DataChannel in a struct fulfilling net.Conn. // Read, Write, and Close operations can still be used on the *Channel struct. func (c *Channel) NetConn() net.Conn { diff --git a/pty/pty.go b/pty/pty.go new file mode 100644 index 0000000000000..0086bfba56c15 --- /dev/null +++ b/pty/pty.go @@ -0,0 +1,39 @@ +package pty + +import ( + "io" +) + +// PTY is a minimal interface for interacting with a TTY. +type PTY interface { + io.Closer + + // Output handles TTY output. + // + // cmd.SetOutput(pty.Output()) would be used to specify a command + // uses the output stream for writing. + // + // The same stream could be read to validate output. + Output() io.ReadWriter + + // Input handles TTY input. + // + // cmd.SetInput(pty.Input()) would be used to specify a command + // uses the PTY input for reading. + // + // The same stream would be used to provide user input: pty.Input().Write(...) + Input() io.ReadWriter + + // Resize sets the size of the PTY. + Resize(cols uint16, rows uint16) error +} + +// New constructs a new Pty. +func New() (PTY, error) { + return newPty() +} + +type readWriter struct { + io.Reader + io.Writer +} diff --git a/console/pty/pty_other.go b/pty/pty_other.go similarity index 52% rename from console/pty/pty_other.go rename to pty/pty_other.go index 723a6dbfd748a..dbdda408b1365 100644 --- a/console/pty/pty_other.go +++ b/pty/pty_other.go @@ -10,46 +10,44 @@ import ( "github.com/creack/pty" ) -func newPty() (Pty, error) { +func newPty() (PTY, error) { ptyFile, ttyFile, err := pty.Open() if err != nil { return nil, err } - return &unixPty{ + return &otherPty{ pty: ptyFile, tty: ttyFile, }, nil } -type unixPty struct { +type otherPty struct { pty, tty *os.File } -func (p *unixPty) InPipe() *os.File { - return p.tty -} - -func (p *unixPty) OutPipe() *os.File { - return p.tty -} - -func (p *unixPty) Reader() io.Reader { - return p.pty +func (p *otherPty) Input() io.ReadWriter { + return readWriter{ + Reader: p.tty, + Writer: p.pty, + } } -func (p *unixPty) WriteString(str string) (int, error) { - return p.pty.WriteString(str) +func (p *otherPty) Output() io.ReadWriter { + return readWriter{ + Reader: p.pty, + Writer: p.tty, + } } -func (p *unixPty) Resize(cols uint16, rows uint16) error { +func (p *otherPty) Resize(cols uint16, rows uint16) error { return pty.Setsize(p.tty, &pty.Winsize{ Rows: rows, Cols: cols, }) } -func (p *unixPty) Close() error { +func (p *otherPty) Close() error { err := p.pty.Close() if err != nil { return err diff --git a/pty/pty_windows.go b/pty/pty_windows.go new file mode 100644 index 0000000000000..b6a9f8ae2e5dd --- /dev/null +++ b/pty/pty_windows.go @@ -0,0 +1,107 @@ +//go:build windows +// +build windows + +package pty + +import ( + "io" + "os" + "sync" + "unsafe" + + "golang.org/x/sys/windows" + + "golang.org/x/xerrors" +) + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + procResizePseudoConsole = kernel32.NewProc("ResizePseudoConsole") + procCreatePseudoConsole = kernel32.NewProc("CreatePseudoConsole") + procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") +) + +// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session +func newPty() (PTY, error) { + // We use the CreatePseudoConsole API which was introduced in build 17763 + vsn := windows.RtlGetVersion() + if vsn.MajorVersion < 10 || + vsn.BuildNumber < 17763 { + // If the CreatePseudoConsole API is not available, we fall back to a simpler + // implementation that doesn't create an actual PTY - just uses os.Pipe + return nil, xerrors.Errorf("pty not supported") + } + + ptyWindows := &ptyWindows{} + + var err error + ptyWindows.inputRead, ptyWindows.inputWrite, err = os.Pipe() + if err != nil { + return nil, err + } + ptyWindows.outputRead, ptyWindows.outputWrite, err = os.Pipe() + + consoleSize := uintptr(80) + (uintptr(80) << 16) + ret, _, err := procCreatePseudoConsole.Call( + consoleSize, + uintptr(ptyWindows.inputRead.Fd()), + uintptr(ptyWindows.outputWrite.Fd()), + 0, + uintptr(unsafe.Pointer(&ptyWindows.console)), + ) + if int32(ret) < 0 { + return nil, xerrors.Errorf("create pseudo console (%d): %w", int32(ret), err) + } + return ptyWindows, nil +} + +type ptyWindows struct { + console windows.Handle + + outputWrite *os.File + outputRead *os.File + inputWrite *os.File + inputRead *os.File + + closeMutex sync.Mutex + closed bool +} + +func (p *ptyWindows) Output() io.ReadWriter { + return readWriter{ + Reader: p.outputRead, + Writer: p.outputWrite, + } +} + +func (p *ptyWindows) Input() io.ReadWriter { + return readWriter{ + Reader: p.inputRead, + Writer: p.inputWrite, + } +} + +func (p *ptyWindows) Resize(cols uint16, rows uint16) error { + ret, _, err := procResizePseudoConsole.Call(uintptr(p.console), uintptr(cols)+(uintptr(rows)<<16)) + if ret != 0 { + return err + } + return nil +} + +func (p *ptyWindows) Close() error { + p.closeMutex.Lock() + defer p.closeMutex.Unlock() + if p.closed { + return nil + } + p.closed = true + + ret, _, err := procClosePseudoConsole.Call(uintptr(p.console)) + if ret != 0 { + return xerrors.Errorf("close pseudo console: %w", err) + } + _ = p.outputRead.Close() + _ = p.inputWrite.Close() + return nil +} diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go new file mode 100644 index 0000000000000..7ea5b7a119f0d --- /dev/null +++ b/pty/ptytest/ptytest.go @@ -0,0 +1,95 @@ +package ptytest + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os/exec" + "regexp" + "strings" + "testing" + "unicode/utf8" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/pty" +) + +var ( + // Used to ensure terminal output doesn't have anything crazy! + // See: https://stackoverflow.com/a/29497680 + stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") +) + +func New(t *testing.T) *PTY { + ptty, err := pty.New() + require.NoError(t, err) + return create(t, ptty) +} + +func Start(t *testing.T, cmd *exec.Cmd) *PTY { + ptty, err := pty.Start(cmd) + require.NoError(t, err) + return create(t, ptty) +} + +func create(t *testing.T, ptty pty.PTY) *PTY { + reader, writer := io.Pipe() + scanner := bufio.NewScanner(reader) + t.Cleanup(func() { + _ = reader.Close() + _ = writer.Close() + }) + go func() { + for scanner.Scan() { + if scanner.Err() != nil { + return + } + t.Log(stripAnsi.ReplaceAllString(scanner.Text(), "")) + } + }() + + t.Cleanup(func() { + _ = ptty.Close() + }) + return &PTY{ + t: t, + PTY: ptty, + + outputWriter: writer, + runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax), + } +} + +type PTY struct { + t *testing.T + pty.PTY + + outputWriter io.Writer + runeReader *bufio.Reader +} + +func (p *PTY) ExpectMatch(str string) string { + var buffer bytes.Buffer + multiWriter := io.MultiWriter(&buffer, p.outputWriter) + runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax) + for { + var r rune + r, _, err := p.runeReader.ReadRune() + require.NoError(p.t, err) + _, err = runeWriter.WriteRune(r) + require.NoError(p.t, err) + err = runeWriter.Flush() + require.NoError(p.t, err) + if strings.Contains(buffer.String(), str) { + break + } + } + return buffer.String() +} + +func (p *PTY) WriteLine(str string) { + _, err := fmt.Fprintf(p.PTY.Input(), "%s\n", str) + require.NoError(p.t, err) +} diff --git a/pty/ptytest/ptytest_test.go b/pty/ptytest/ptytest_test.go new file mode 100644 index 0000000000000..6603b35ad59db --- /dev/null +++ b/pty/ptytest/ptytest_test.go @@ -0,0 +1,15 @@ +package ptytest_test + +import ( + "testing" + + "github.com/coder/coder/pty/ptytest" +) + +func TestPtytest(t *testing.T) { + t.Parallel() + pty := ptytest.New(t) + pty.Output().Write([]byte("write")) + pty.ExpectMatch("write") + pty.WriteLine("read") +} diff --git a/pty/start.go b/pty/start.go new file mode 100644 index 0000000000000..2b75843ee16c2 --- /dev/null +++ b/pty/start.go @@ -0,0 +1,7 @@ +package pty + +import "os/exec" + +func Start(cmd *exec.Cmd) (PTY, error) { + return startPty(cmd) +} diff --git a/pty/start_other.go b/pty/start_other.go new file mode 100644 index 0000000000000..103f55202efe3 --- /dev/null +++ b/pty/start_other.go @@ -0,0 +1,34 @@ +//go:build !windows +// +build !windows + +package pty + +import ( + "os/exec" + "syscall" + + "github.com/creack/pty" +) + +func startPty(cmd *exec.Cmd) (PTY, error) { + ptty, tty, err := pty.Open() + if err != nil { + return nil, err + } + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, + Setctty: true, + } + cmd.Stdout = tty + cmd.Stderr = tty + cmd.Stdin = tty + err = cmd.Start() + if err != nil { + _ = ptty.Close() + return nil, err + } + return &otherPty{ + pty: ptty, + tty: tty, + }, nil +} diff --git a/pty/start_other_test.go b/pty/start_other_test.go new file mode 100644 index 0000000000000..a5e7d94b36af1 --- /dev/null +++ b/pty/start_other_test.go @@ -0,0 +1,25 @@ +//go:build !windows +// +build !windows + +package pty_test + +import ( + "os/exec" + "testing" + + "github.com/coder/coder/pty/ptytest" + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestStart(t *testing.T) { + t.Parallel() + t.Run("Echo", func(t *testing.T) { + t.Parallel() + pty := ptytest.Start(t, exec.Command("echo", "test")) + pty.ExpectMatch("test") + }) +} diff --git a/pty/start_windows.go b/pty/start_windows.go new file mode 100644 index 0000000000000..136ba245736ab --- /dev/null +++ b/pty/start_windows.go @@ -0,0 +1,149 @@ +//go:build windows +// +build windows + +package pty + +import ( + "os" + "os/exec" + "strings" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +// Allocates a PTY and starts the specified command attached to it. +// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process +func startPty(cmd *exec.Cmd) (PTY, error) { + fullPath, err := exec.LookPath(cmd.Path) + if err != nil { + return nil, err + } + pathPtr, err := windows.UTF16PtrFromString(fullPath) + if err != nil { + return nil, err + } + argsPtr, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(cmd.Args)) + if err != nil { + return nil, err + } + if cmd.Dir == "" { + cmd.Dir, err = os.Getwd() + if err != nil { + return nil, err + } + } + dirPtr, err := windows.UTF16PtrFromString(cmd.Dir) + if err != nil { + return nil, err + } + pty, err := newPty() + if err != nil { + return nil, err + } + winPty := pty.(*ptyWindows) + + attrs, err := windows.NewProcThreadAttributeList(1) + if err != nil { + return nil, err + } + // Taken from: https://github.com/microsoft/hcsshim/blob/2314362e977aa03b3ed245a4beb12d00422af0e2/internal/winapi/process.go#L6 + err = attrs.Update(0x20016, unsafe.Pointer(winPty.console), unsafe.Sizeof(winPty.console)) + if err != nil { + return nil, err + } + + startupInfo := &windows.StartupInfoEx{} + startupInfo.ProcThreadAttributeList = attrs.List() + startupInfo.StartupInfo.Flags = windows.STARTF_USESTDHANDLES + startupInfo.StartupInfo.Cb = uint32(unsafe.Sizeof(*startupInfo)) + var processInfo windows.ProcessInformation + err = windows.CreateProcess( + pathPtr, + argsPtr, + nil, + nil, + false, + // https://docs.microsoft.com/en-us/windows/win32/procthread/process-creation-flags#create_unicode_environment + windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT, + createEnvBlock(addCriticalEnv(dedupEnvCase(true, cmd.Env))), + dirPtr, + &startupInfo.StartupInfo, + &processInfo, + ) + if err != nil { + return nil, err + } + defer windows.CloseHandle(processInfo.Thread) + defer windows.CloseHandle(processInfo.Process) + + return pty, nil +} + +// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476 +func createEnvBlock(envv []string) *uint16 { + if len(envv) == 0 { + return &utf16.Encode([]rune("\x00\x00"))[0] + } + length := 0 + for _, s := range envv { + length += len(s) + 1 + } + length += 1 + + b := make([]byte, length) + i := 0 + for _, s := range envv { + l := len(s) + copy(b[i:i+l], []byte(s)) + copy(b[i+l:i+l+1], []byte{0}) + i = i + l + 1 + } + copy(b[i:i+1], []byte{0}) + + return &utf16.Encode([]rune(string(b)))[0] +} + +// dedupEnvCase is dedupEnv with a case option for testing. +// If caseInsensitive is true, the case of keys is ignored. +func dedupEnvCase(caseInsensitive bool, env []string) []string { + out := make([]string, 0, len(env)) + saw := make(map[string]int, len(env)) // key => index into out + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + out = append(out, kv) + continue + } + k := kv[:eq] + if caseInsensitive { + k = strings.ToLower(k) + } + if dupIdx, isDup := saw[k]; isDup { + out[dupIdx] = kv + continue + } + saw[k] = len(out) + out = append(out, kv) + } + return out +} + +// addCriticalEnv adds any critical environment variables that are required +// (or at least almost always required) on the operating system. +// Currently this is only used for Windows. +func addCriticalEnv(env []string) []string { + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + continue + } + k := kv[:eq] + if strings.EqualFold(k, "SYSTEMROOT") { + // We already have it. + return env + } + } + return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) +} diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go new file mode 100644 index 0000000000000..faee269776830 --- /dev/null +++ b/pty/start_windows_test.go @@ -0,0 +1,32 @@ +//go:build windows +// +build windows + +package pty_test + +import ( + "os/exec" + "testing" + + "github.com/coder/coder/pty/ptytest" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestStart(t *testing.T) { + t.Parallel() + t.Run("Echo", func(t *testing.T) { + t.Parallel() + pty := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) + pty.ExpectMatch("test") + }) + t.Run("Resize", func(t *testing.T) { + t.Parallel() + pty := ptytest.Start(t, exec.Command("cmd.exe")) + err := pty.Resize(100, 50) + require.NoError(t, err) + }) +}