diff --git a/cli/clitest/clitest.go b/cli/clitest/clitest.go index e9fbbd4f23d1d..f696ca0d988e7 100644 --- a/cli/clitest/clitest.go +++ b/cli/clitest/clitest.go @@ -2,16 +2,13 @@ package clitest import ( "archive/tar" - "bufio" "bytes" "errors" "io" "os" "path/filepath" - "regexp" "testing" - "github.com/Netflix/go-expect" "github.com/spf13/cobra" "github.com/stretchr/testify/require" @@ -21,12 +18,6 @@ import ( "github.com/coder/coder/provisioner/echo" ) -var ( - // Used to ensure terminal output doesn't have anything crazy! - // See: https://stackoverflow.com/a/29497680 - stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") -) - // New creates a CLI instance with a configuration pointed to a // temporary testing directory. func New(t *testing.T, args ...string) (*cobra.Command, config.Root) { @@ -55,31 +46,6 @@ func CreateProjectVersionSource(t *testing.T, responses *echo.Responses) string return directory } -// NewConsole creates a new TTY bound to the command provided. -// All ANSI escape codes are stripped to provide clean output. -func NewConsole(t *testing.T, cmd *cobra.Command) *expect.Console { - reader, writer := io.Pipe() - scanner := bufio.NewScanner(reader) - t.Cleanup(func() { - _ = reader.Close() - _ = writer.Close() - }) - go func() { - for scanner.Scan() { - if scanner.Err() != nil { - return - } - t.Log(stripAnsi.ReplaceAllString(scanner.Text(), "")) - } - }() - - console, err := expect.NewConsole(expect.WithStdout(writer)) - require.NoError(t, err) - cmd.SetIn(console.Tty()) - cmd.SetOut(console.Tty()) - return console -} - func extractTar(t *testing.T, data []byte, directory string) { reader := tar.NewReader(bytes.NewBuffer(data)) for { diff --git a/cli/clitest/clitest_test.go b/cli/clitest/clitest_test.go index 806e04ecc2a4e..f5be5a45db12c 100644 --- a/cli/clitest/clitest_test.go +++ b/cli/clitest/clitest_test.go @@ -1,5 +1,3 @@ -//go:build !windows - package clitest_test import ( @@ -7,6 +5,7 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/expect" "github.com/stretchr/testify/require" "go.uber.org/goleak" ) @@ -21,7 +20,7 @@ func TestCli(t *testing.T) { client := coderdtest.New(t) cmd, config := clitest.New(t) clitest.SetupConfig(t, client, config) - console := clitest.NewConsole(t, cmd) + console := expect.NewTestConsole(t, cmd) go func() { err := cmd.Execute() require.NoError(t, err) diff --git a/cli/login.go b/cli/login.go index a1df7e905d7dd..5910b5846ddcd 100644 --- a/cli/login.go +++ b/cli/login.go @@ -22,6 +22,7 @@ func login() *cobra.Command { Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { rawURL := args[0] + if !strings.HasPrefix(rawURL, "http://") && !strings.HasPrefix(rawURL, "https://") { scheme := "https" if strings.HasPrefix(rawURL, "localhost") { @@ -44,7 +45,7 @@ func login() *cobra.Command { return xerrors.Errorf("has initial user: %w", err) } if !hasInitialUser { - if !isTTY(cmd.InOrStdin()) { + if !isTTY(cmd) { return xerrors.New("the initial user cannot be created in non-interactive mode. use the API") } _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s Your Coder deployment hasn't been set up!\n", color.HiBlackString(">")) diff --git a/cli/login_test.go b/cli/login_test.go index 06f942ee95b9c..43859ba56199c 100644 --- a/cli/login_test.go +++ b/cli/login_test.go @@ -1,11 +1,10 @@ -//go:build !windows - package cli_test import ( "testing" "github.com/coder/coder/cli/clitest" + "github.com/coder/coder/expect" "github.com/coder/coder/coderd/coderdtest" "github.com/stretchr/testify/require" ) @@ -23,8 +22,11 @@ func TestLogin(t *testing.T) { t.Run("InitialUserTTY", func(t *testing.T) { t.Parallel() client := coderdtest.New(t) - root, _ := clitest.New(t, "login", client.URL.String()) - console := clitest.NewConsole(t, root) + // The --force-tty flag is required on Windows, because the `isatty` library does not + // accurately detect Windows ptys when they are not attached to a process: + // https://github.com/mattn/go-isatty/issues/59 + root, _ := clitest.New(t, "login", client.URL.String(), "--force-tty") + console := expect.NewTestConsole(t, root) go func() { err := root.Execute() require.NoError(t, err) diff --git a/cli/projectcreate_test.go b/cli/projectcreate_test.go index ed802475ffe94..0cd654ec657d5 100644 --- a/cli/projectcreate_test.go +++ b/cli/projectcreate_test.go @@ -1,5 +1,3 @@ -//go:build !windows - package cli_test import ( @@ -10,6 +8,7 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/database" + "github.com/coder/coder/expect" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" ) @@ -27,7 +26,7 @@ func TestProjectCreate(t *testing.T) { cmd, root := clitest.New(t, "projects", "create", "--directory", source, "--provisioner", string(database.ProvisionerTypeEcho)) clitest.SetupConfig(t, client, root) _ = coderdtest.NewProvisionerDaemon(t, client) - console := clitest.NewConsole(t, cmd) + console := expect.NewTestConsole(t, cmd) closeChan := make(chan struct{}) go func() { err := cmd.Execute() @@ -74,7 +73,7 @@ func TestProjectCreate(t *testing.T) { cmd, root := clitest.New(t, "projects", "create", "--directory", source, "--provisioner", string(database.ProvisionerTypeEcho)) clitest.SetupConfig(t, client, root) coderdtest.NewProvisionerDaemon(t, client) - console := clitest.NewConsole(t, cmd) + console := expect.NewTestConsole(t, cmd) closeChan := make(chan struct{}) go func() { err := cmd.Execute() diff --git a/cli/root.go b/cli/root.go index 9133d7655d133..f4e27a49d9e67 100644 --- a/cli/root.go +++ b/cli/root.go @@ -21,6 +21,7 @@ import ( const ( varGlobalConfig = "global-config" + varForceTty = "force-tty" ) func Root() *cobra.Command { @@ -65,6 +66,12 @@ func Root() *cobra.Command { cmd.AddCommand(users()) cmd.PersistentFlags().String(varGlobalConfig, configdir.LocalConfig("coder"), "Path to the global `coder` config directory") + cmd.PersistentFlags().Bool(varForceTty, false, "Force the `coder` command to run as if connected to a TTY") + err := cmd.PersistentFlags().MarkHidden(varForceTty) + if err != nil { + // This should never return an error, because we just added the `--force-tty`` flag prior to calling MarkHidden. + panic(err) + } return cmd } @@ -113,7 +120,16 @@ func createConfig(cmd *cobra.Command) config.Root { // isTTY returns whether the passed reader is a TTY or not. // This accepts a reader to work with Cobra's "InOrStdin" // function for simple testing. -func isTTY(reader io.Reader) bool { +func isTTY(cmd *cobra.Command) bool { + // If the `--force-tty` command is available, and set, + // assume we're in a tty. This is primarily for cases on Windows + // where we may not be able to reliably detect this automatically (ie, tests) + forceTty, err := cmd.Flags().GetBool(varForceTty) + if forceTty && err == nil { + return true + } + + reader := cmd.InOrStdin() file, ok := reader.(*os.File) if !ok { return false diff --git a/cli/workspacecreate_test.go b/cli/workspacecreate_test.go index 138e0ee1e61d6..4112223a61a4d 100644 --- a/cli/workspacecreate_test.go +++ b/cli/workspacecreate_test.go @@ -1,5 +1,3 @@ -//go:build !windows - package cli_test import ( @@ -7,6 +5,7 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/expect" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" "github.com/stretchr/testify/require" @@ -37,7 +36,7 @@ func TestWorkspaceCreate(t *testing.T) { cmd, root := clitest.New(t, "workspaces", "create", project.Name) clitest.SetupConfig(t, client, root) - console := clitest.NewConsole(t, cmd) + console := expect.NewTestConsole(t, cmd) closeChan := make(chan struct{}) go func() { err := cmd.Execute() diff --git a/expect/conpty/conpty.go b/expect/conpty/conpty.go new file mode 100644 index 0000000000000..a57264b8ff195 --- /dev/null +++ b/expect/conpty/conpty.go @@ -0,0 +1,107 @@ +//go:build windows +// +build windows + +// Original copyright 2020 ActiveState Software. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file + +package conpty + +import ( + "fmt" + "io" + "os" + + "golang.org/x/sys/windows" +) + +// ConPty represents a windows pseudo console. +type ConPty struct { + hpCon windows.Handle + outPipePseudoConsoleSide windows.Handle + outPipeOurSide windows.Handle + inPipeOurSide windows.Handle + inPipePseudoConsoleSide windows.Handle + consoleSize uintptr + outFilePseudoConsoleSide *os.File + outFileOurSide *os.File + inFilePseudoConsoleSide *os.File + inFileOurSide *os.File + closed bool +} + +// New returns a new ConPty pseudo terminal device +func New(columns int16, rows int16) (*ConPty, error) { + c := &ConPty{ + consoleSize: uintptr(columns) + (uintptr(rows) << 16), + } + + return c, c.createPseudoConsoleAndPipes() +} + +// Close closes the pseudo-terminal and cleans up all attached resources +func (c *ConPty) Close() error { + // Trying to close these pipes multiple times will result in an + // access violation + if c.closed { + return nil + } + + err := closePseudoConsole(c.hpCon) + c.outFilePseudoConsoleSide.Close() + c.outFileOurSide.Close() + c.inFilePseudoConsoleSide.Close() + c.inFileOurSide.Close() + c.closed = true + return err +} + +// OutPipe returns the output pipe of the pseudo terminal +func (c *ConPty) OutPipe() *os.File { + return c.outFilePseudoConsoleSide +} + +func (c *ConPty) Reader() io.Reader { + return c.outFileOurSide +} + +// InPipe returns input pipe of the pseudo terminal +// Note: It is safer to use the Write method to prevent partially-written VT sequences +// from corrupting the terminal +func (c *ConPty) InPipe() *os.File { + return c.inFilePseudoConsoleSide +} + +func (c *ConPty) WriteString(str string) (int, error) { + return c.inFileOurSide.WriteString(str) +} + +func (c *ConPty) createPseudoConsoleAndPipes() error { + // Create the stdin pipe + if err := windows.CreatePipe(&c.inPipePseudoConsoleSide, &c.inPipeOurSide, nil, 0); err != nil { + return err + } + + // Create the stdout pipe + if err := windows.CreatePipe(&c.outPipeOurSide, &c.outPipePseudoConsoleSide, nil, 0); err != nil { + return err + } + + // Create the pty with our stdin/stdout + if err := createPseudoConsole(c.consoleSize, c.inPipePseudoConsoleSide, c.outPipePseudoConsoleSide, &c.hpCon); err != nil { + return fmt.Errorf("failed to create pseudo console: %d, %v", uintptr(c.hpCon), err) + } + + c.outFilePseudoConsoleSide = os.NewFile(uintptr(c.outPipePseudoConsoleSide), "|0") + c.outFileOurSide = os.NewFile(uintptr(c.outPipeOurSide), "|1") + + c.inFilePseudoConsoleSide = os.NewFile(uintptr(c.inPipePseudoConsoleSide), "|2") + c.inFileOurSide = os.NewFile(uintptr(c.inPipeOurSide), "|3") + c.closed = false + + return nil +} + +func (c *ConPty) Resize(cols uint16, rows uint16) error { + return resizePseudoConsole(c.hpCon, uintptr(cols)+(uintptr(rows)<<16)) +} diff --git a/expect/conpty/syscall.go b/expect/conpty/syscall.go new file mode 100644 index 0000000000000..284603aa8fdc7 --- /dev/null +++ b/expect/conpty/syscall.go @@ -0,0 +1,53 @@ +//go:build windows +// +build windows + +// Copyright 2020 ActiveState Software. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file + +package conpty + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + procResizePseudoConsole = kernel32.NewProc("ResizePseudoConsole") + procCreatePseudoConsole = kernel32.NewProc("CreatePseudoConsole") + procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") +) + +func createPseudoConsole(consoleSize uintptr, ptyIn windows.Handle, ptyOut windows.Handle, hpCon *windows.Handle) (err error) { + r1, _, e1 := procCreatePseudoConsole.Call( + consoleSize, + uintptr(ptyIn), + uintptr(ptyOut), + 0, + uintptr(unsafe.Pointer(hpCon)), + ) + + if r1 != 0 { // !S_OK + err = e1 + } + return +} + +func resizePseudoConsole(handle windows.Handle, consoleSize uintptr) (err error) { + r1, _, e1 := procResizePseudoConsole.Call(uintptr(handle), consoleSize) + if r1 != 0 { // !S_OK + err = e1 + } + return +} + +func closePseudoConsole(handle windows.Handle) (err error) { + r1, _, e1 := procClosePseudoConsole.Call(uintptr(handle)) + if r1 == 0 { + err = e1 + } + + return +} diff --git a/expect/console.go b/expect/console.go new file mode 100644 index 0000000000000..3a9592cce7ba0 --- /dev/null +++ b/expect/console.go @@ -0,0 +1,163 @@ +// Copyright 2018 Netflix, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expect + +import ( + "bufio" + "fmt" + "io" + "io/ioutil" + "log" + "os" + "unicode/utf8" + + "github.com/coder/coder/expect/pty" +) + +// Console is an interface to automate input and output for interactive +// applications. Console can block until a specified output is received and send +// input back on it's tty. Console can also multiplex other sources of input +// and multiplex its output to other writers. +type Console struct { + opts ConsoleOpts + pty pty.Pty + runeReader *bufio.Reader + closers []io.Closer +} + +// ConsoleOpt allows setting Console options. +type ConsoleOpt func(*ConsoleOpts) error + +// ConsoleOpts provides additional options on creating a Console. +type ConsoleOpts struct { + Logger *log.Logger + Stdouts []io.Writer + ExpectObservers []Observer +} + +// Observer provides an interface for a function callback that will +// be called after each Expect operation. +// matchers will be the list of active matchers when an error occurred, +// or a list of matchers that matched `buf` when err is nil. +// buf is the captured output that was matched against. +// err is error that might have occurred. May be nil. +type Observer func(matchers []Matcher, buf string, err error) + +// WithStdout adds writers that Console duplicates writes to, similar to the +// Unix tee(1) command. +// +// Each write is written to each listed writer, one at a time. Console is the +// last writer, writing to it's internal buffer for matching expects. +// If a listed writer returns an error, that overall write operation stops and +// returns the error; it does not continue down the list. +func WithStdout(writers ...io.Writer) ConsoleOpt { + return func(opts *ConsoleOpts) error { + opts.Stdouts = append(opts.Stdouts, writers...) + return nil + } +} + +// WithLogger adds a logger for Console to log debugging information to. By +// default Console will discard logs. +func WithLogger(logger *log.Logger) ConsoleOpt { + return func(opts *ConsoleOpts) error { + opts.Logger = logger + return nil + } +} + +// WithExpectObserver adds an ExpectObserver to allow monitoring Expect operations. +func WithExpectObserver(observers ...Observer) ConsoleOpt { + return func(opts *ConsoleOpts) error { + opts.ExpectObservers = append(opts.ExpectObservers, observers...) + return nil + } +} + +// NewConsole returns a new Console with the given options. +func NewConsole(opts ...ConsoleOpt) (*Console, error) { + options := ConsoleOpts{ + Logger: log.New(ioutil.Discard, "", 0), + } + + for _, opt := range opts { + if err := opt(&options); err != nil { + return nil, err + } + } + + consolePty, err := pty.New() + if err != nil { + return nil, err + } + closers := []io.Closer{consolePty} + reader := consolePty.Reader() + + console := &Console{ + opts: options, + pty: consolePty, + runeReader: bufio.NewReaderSize(reader, utf8.UTFMax), + closers: closers, + } + + return console, nil +} + +// Tty returns an input Tty for accepting input +func (c *Console) InTty() *os.File { + return c.pty.InPipe() +} + +// OutTty returns an output tty for writing +func (c *Console) OutTty() *os.File { + return c.pty.OutPipe() +} + +// Close closes Console's tty. Calling Close will unblock Expect and ExpectEOF. +func (c *Console) Close() error { + for _, fd := range c.closers { + err := fd.Close() + if err != nil { + c.Logf("failed to close: %s", err) + } + } + return nil +} + +// Send writes string s to Console's tty. +func (c *Console) Send(s string) (int, error) { + c.Logf("console send: %q", s) + n, err := c.pty.WriteString(s) + return n, err +} + +// SendLine writes string s to Console's tty with a trailing newline. +func (c *Console) SendLine(s string) (int, error) { + bytes, err := c.Send(fmt.Sprintf("%s\n", s)) + + return bytes, err +} + +// Log prints to Console's logger. +// Arguments are handled in the manner of fmt.Print. +func (c *Console) Log(v ...interface{}) { + c.opts.Logger.Print(v...) +} + +// Logf prints to Console's logger. +// Arguments are handled in the manner of fmt.Printf. +func (c *Console) Logf(format string, v ...interface{}) { + c.opts.Logger.Printf(format, v...) +} diff --git a/expect/doc.go b/expect/doc.go new file mode 100644 index 0000000000000..a0163f0e508d5 --- /dev/null +++ b/expect/doc.go @@ -0,0 +1,19 @@ +// Copyright 2018 Netflix, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package expect provides an expect-like interface to automate control of +// applications. It is unlike expect in that it does not spawn or manage +// process lifecycle. This package only focuses on expecting output and sending +// input through it's psuedoterminal. +package expect diff --git a/expect/expect.go b/expect/expect.go new file mode 100644 index 0000000000000..be266ca049434 --- /dev/null +++ b/expect/expect.go @@ -0,0 +1,109 @@ +// Copyright 2018 Netflix, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expect + +import ( + "bufio" + "bytes" + "fmt" + "io" + "unicode/utf8" +) + +// Expectf reads from the Console's tty until the provided formatted string +// is read or an error occurs, and returns the buffer read by Console. +func (c *Console) Expectf(format string, args ...interface{}) (string, error) { + return c.Expect(String(fmt.Sprintf(format, args...))) +} + +// ExpectString reads from Console's tty until the provided string is read or +// an error occurs, and returns the buffer read by Console. +func (c *Console) ExpectString(s string) (string, error) { + return c.Expect(String(s)) +} + +// Expect reads from Console's tty until a condition specified from opts is +// encountered or an error occurs, and returns the buffer read by console. +// No extra bytes are read once a condition is met, so if a program isn't +// expecting input yet, it will be blocked. Sends are queued up in tty's +// internal buffer so that the next Expect will read the remaining bytes (i.e. +// rest of prompt) as well as its conditions. +func (c *Console) Expect(opts ...Opt) (string, error) { + var options Opts + for _, opt := range opts { + if err := opt(&options); err != nil { + return "", err + } + } + + buf := new(bytes.Buffer) + writer := io.MultiWriter(append(c.opts.Stdouts, buf)...) + runeWriter := bufio.NewWriterSize(writer, utf8.UTFMax) + + var matcher Matcher + var err error + + defer func() { + for _, observer := range c.opts.ExpectObservers { + if matcher != nil { + observer([]Matcher{matcher}, buf.String(), err) + return + } + observer(options.Matchers, buf.String(), err) + } + }() + + for { + var r rune + r, _, err = c.runeReader.ReadRune() + if err != nil { + matcher = options.Match(err) + if matcher != nil { + err = nil + break + } + return buf.String(), err + } + + c.Logf("expect read: %q", string(r)) + _, err = runeWriter.WriteRune(r) + if err != nil { + return buf.String(), err + } + + // Immediately flush rune to the underlying writers. + err = runeWriter.Flush() + if err != nil { + return buf.String(), err + } + + matcher = options.Match(buf) + if matcher != nil { + break + } + } + + if matcher != nil { + cb, ok := matcher.(CallbackMatcher) + if ok { + err = cb.Callback(buf) + if err != nil { + return buf.String(), err + } + } + } + + return buf.String(), err +} diff --git a/expect/expect_opt.go b/expect/expect_opt.go new file mode 100644 index 0000000000000..9262d6df12f57 --- /dev/null +++ b/expect/expect_opt.go @@ -0,0 +1,139 @@ +// Copyright 2018 Netflix, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expect + +import ( + "bytes" + "strings" + "time" +) + +// Opt allows settings Expect options. +type Opt func(*Opts) error + +// ConsoleCallback is a callback function to execute if a match is found for +// the chained matcher. +type ConsoleCallback func(buf *bytes.Buffer) error + +// Opts provides additional options on Expect. +type Opts struct { + Matchers []Matcher + ReadTimeout *time.Duration +} + +// Match sequentially calls Match on all matchers in ExpectOpts and returns the +// first matcher if a match exists, otherwise nil. +func (eo Opts) Match(v interface{}) Matcher { + for _, matcher := range eo.Matchers { + if matcher.Match(v) { + return matcher + } + } + return nil +} + +// CallbackMatcher is a matcher that provides a Callback function. +type CallbackMatcher interface { + // Callback executes the matcher's callback with the content buffer at the + // time of match. + Callback(buf *bytes.Buffer) error +} + +// Matcher provides an interface for finding a match in content read from +// Console's tty. +type Matcher interface { + // Match returns true iff a match is found. + Match(v interface{}) bool + Criteria() interface{} +} + +// stringMatcher fulfills the Matcher interface to match strings against a given +// bytes.Buffer. +type stringMatcher struct { + str string +} + +func (sm *stringMatcher) Match(v interface{}) bool { + buf, ok := v.(*bytes.Buffer) + if !ok { + return false + } + if strings.Contains(buf.String(), sm.str) { + return true + } + return false +} + +func (sm *stringMatcher) Criteria() interface{} { + return sm.str +} + +// allMatcher fulfills the Matcher interface to match a group of ExpectOpt +// against any value. +type allMatcher struct { + options Opts +} + +func (am *allMatcher) Match(v interface{}) bool { + var matchers []Matcher + for _, matcher := range am.options.Matchers { + if matcher.Match(v) { + continue + } + matchers = append(matchers, matcher) + } + + am.options.Matchers = matchers + return len(matchers) == 0 +} + +func (am *allMatcher) Criteria() interface{} { + var criteria []interface{} + for _, matcher := range am.options.Matchers { + criteria = append(criteria, matcher.Criteria()) + } + return criteria +} + +// All adds an Expect condition to exit if the content read from Console's tty +// matches all of the provided ExpectOpt, in any order. +func All(expectOpts ...Opt) Opt { + return func(opts *Opts) error { + var options Opts + for _, opt := range expectOpts { + if err := opt(&options); err != nil { + return err + } + } + + opts.Matchers = append(opts.Matchers, &allMatcher{ + options: options, + }) + return nil + } +} + +// String adds an Expect condition to exit if the content read from Console's +// tty contains any of the given strings. +func String(strs ...string) Opt { + return func(opts *Opts) error { + for _, str := range strs { + opts.Matchers = append(opts.Matchers, &stringMatcher{ + str: str, + }) + } + return nil + } +} diff --git a/expect/expect_opt_test.go b/expect/expect_opt_test.go new file mode 100644 index 0000000000000..e9f5aba95d603 --- /dev/null +++ b/expect/expect_opt_test.go @@ -0,0 +1,163 @@ +// Copyright 2018 Netflix, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expect_test + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + . "github.com/coder/coder/expect" +) + +func TestExpectOptString(t *testing.T) { + t.Parallel() + + tests := []struct { + title string + opt Opt + data string + expected bool + }{ + { + "No args", + String(), + "Hello world", + false, + }, + { + "Single arg", + String("Hello"), + "Hello world", + true, + }, + { + "Multiple arg", + String("other", "world"), + "Hello world", + true, + }, + { + "No matches", + String("hello"), + "Hello world", + false, + }, + } + + for _, test := range tests { + test := test + t.Run(test.title, func(t *testing.T) { + t.Parallel() + + var options Opts + err := test.opt(&options) + require.Nil(t, err) + + buf := new(bytes.Buffer) + _, err = buf.WriteString(test.data) + require.Nil(t, err) + + matcher := options.Match(buf) + if test.expected { + require.NotNil(t, matcher) + } else { + require.Nil(t, matcher) + } + }) + } +} + +func TestExpectOptAll(t *testing.T) { + t.Parallel() + + tests := []struct { + title string + opt Opt + data string + expected bool + }{ + { + "No opts", + All(), + "Hello world", + true, + }, + { + "Single string match", + All(String("Hello")), + "Hello world", + true, + }, + { + "Single string no match", + All(String("Hello")), + "No match", + false, + }, + { + "Ordered strings match", + All(String("Hello"), String("world")), + "Hello world", + true, + }, + { + "Ordered strings not all match", + All(String("Hello"), String("world")), + "Hello", + false, + }, + { + "Unordered strings", + All(String("world"), String("Hello")), + "Hello world", + true, + }, + { + "Unordered strings not all match", + All(String("world"), String("Hello")), + "Hello", + false, + }, + { + "Repeated strings match", + All(String("Hello"), String("Hello")), + "Hello world", + true, + }, + } + + for _, test := range tests { + test := test + t.Run(test.title, func(t *testing.T) { + t.Parallel() + var options Opts + err := test.opt(&options) + require.Nil(t, err) + + buf := new(bytes.Buffer) + _, err = buf.WriteString(test.data) + require.Nil(t, err) + + matcher := options.Match(buf) + if test.expected { + require.NotNil(t, matcher) + } else { + require.Nil(t, matcher) + } + }) + } +} diff --git a/expect/expect_test.go b/expect/expect_test.go new file mode 100644 index 0000000000000..f74fc781f2d94 --- /dev/null +++ b/expect/expect_test.go @@ -0,0 +1,181 @@ +// Copyright 2018 Netflix, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expect_test + +import ( + "bufio" + "errors" + "fmt" + "io" + "runtime/debug" + "strings" + "sync" + "testing" + + "golang.org/x/xerrors" + + . "github.com/coder/coder/expect" +) + +var ( + ErrWrongAnswer = xerrors.New("wrong answer") +) + +type Survey struct { + Prompt string + Answer string +} + +func Prompt(in io.Reader, out io.Writer) error { + reader := bufio.NewReader(in) + + for _, survey := range []Survey{ + { + "What is 1+1?", "2", + }, + { + "What is Netflix backwards?", "xilfteN", + }, + } { + _, err := fmt.Fprintf(out, "%s: ", survey.Prompt) + if err != nil { + return err + } + text, err := reader.ReadString('\n') + if err != nil { + return err + } + + _, err = fmt.Fprint(out, text) + if err != nil { + return err + } + text = strings.TrimSpace(text) + if text != survey.Answer { + return ErrWrongAnswer + } + } + + return nil +} + +func newTestConsole(t *testing.T, opts ...ConsoleOpt) (*Console, error) { + opts = append([]ConsoleOpt{ + expectNoError(t), + }, opts...) + return NewConsole(opts...) +} + +func expectNoError(t *testing.T) ConsoleOpt { + return WithExpectObserver( + func(matchers []Matcher, buf string, err error) { + if err == nil { + return + } + if len(matchers) == 0 { + t.Fatalf("Error occurred while matching %q: %s\n%s", buf, err, string(debug.Stack())) + } else { + var criteria []string + for _, matcher := range matchers { + criteria = append(criteria, fmt.Sprintf("%q", matcher.Criteria())) + } + t.Fatalf("Failed to find [%s] in %q: %s\n%s", strings.Join(criteria, ", "), buf, err, string(debug.Stack())) + } + }, + ) +} + +func testCloser(t *testing.T, closer io.Closer) { + if err := closer.Close(); err != nil { + t.Errorf("Close failed: %s", err) + debug.PrintStack() + } +} + +func TestExpectf(t *testing.T) { + t.Parallel() + + console, err := newTestConsole(t) + if err != nil { + t.Errorf("Expected no error but got'%s'", err) + } + defer testCloser(t, console) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + console.Expectf("What is 1+%d?", 1) + console.SendLine("2") + console.Expectf("What is %s backwards?", "Netflix") + console.SendLine("xilfteN") + }() + + err = Prompt(console.InTty(), console.OutTty()) + if err != nil { + t.Errorf("Expected no error but got '%s'", err) + } + wg.Wait() +} + +func TestExpect(t *testing.T) { + t.Parallel() + + console, err := newTestConsole(t) + if err != nil { + t.Errorf("Expected no error but got'%s'", err) + } + defer testCloser(t, console) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + console.ExpectString("What is 1+1?") + console.SendLine("2") + console.ExpectString("What is Netflix backwards?") + console.SendLine("xilfteN") + }() + + err = Prompt(console.InTty(), console.OutTty()) + if err != nil { + t.Errorf("Expected no error but got '%s'", err) + } + wg.Wait() +} + +func TestExpectOutput(t *testing.T) { + t.Parallel() + + console, err := newTestConsole(t) + if err != nil { + t.Errorf("Expected no error but got'%s'", err) + } + defer testCloser(t, console) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + console.ExpectString("What is 1+1?") + console.SendLine("3") + }() + + err = Prompt(console.InTty(), console.OutTty()) + if err == nil || !errors.Is(err, ErrWrongAnswer) { + t.Errorf("Expected error '%s' but got '%s' instead", ErrWrongAnswer, err) + } + wg.Wait() +} diff --git a/expect/pty/pty.go b/expect/pty/pty.go new file mode 100644 index 0000000000000..86b56e68f922e --- /dev/null +++ b/expect/pty/pty.go @@ -0,0 +1,21 @@ +package pty + +import ( + "io" + "os" +) + +// Pty is the minimal pseudo-tty interface we require. +type Pty interface { + InPipe() *os.File + OutPipe() *os.File + Resize(cols uint16, rows uint16) error + WriteString(str string) (int, error) + Reader() io.Reader + Close() error +} + +// New creates a new Pty. +func New() (Pty, error) { + return newPty() +} diff --git a/expect/pty/pty_other.go b/expect/pty/pty_other.go new file mode 100644 index 0000000000000..723a6dbfd748a --- /dev/null +++ b/expect/pty/pty_other.go @@ -0,0 +1,63 @@ +//go:build !windows +// +build !windows + +package pty + +import ( + "io" + "os" + + "github.com/creack/pty" +) + +func newPty() (Pty, error) { + ptyFile, ttyFile, err := pty.Open() + if err != nil { + return nil, err + } + + return &unixPty{ + pty: ptyFile, + tty: ttyFile, + }, nil +} + +type unixPty struct { + pty, tty *os.File +} + +func (p *unixPty) InPipe() *os.File { + return p.tty +} + +func (p *unixPty) OutPipe() *os.File { + return p.tty +} + +func (p *unixPty) Reader() io.Reader { + return p.pty +} + +func (p *unixPty) WriteString(str string) (int, error) { + return p.pty.WriteString(str) +} + +func (p *unixPty) Resize(cols uint16, rows uint16) error { + return pty.Setsize(p.tty, &pty.Winsize{ + Rows: rows, + Cols: cols, + }) +} + +func (p *unixPty) Close() error { + err := p.pty.Close() + if err != nil { + return err + } + + err = p.tty.Close() + if err != nil { + return err + } + return nil +} diff --git a/expect/pty/pty_windows.go b/expect/pty/pty_windows.go new file mode 100644 index 0000000000000..1d8645840516d --- /dev/null +++ b/expect/pty/pty_windows.go @@ -0,0 +1,78 @@ +//go:build windows +// +build windows + +package pty + +import ( + "io" + "os" + + "golang.org/x/sys/windows" + + "github.com/coder/coder/expect/conpty" +) + +func newPty() (Pty, error) { + // We use the CreatePseudoConsole API which was introduced in build 17763 + vsn := windows.RtlGetVersion() + if vsn.MajorVersion < 10 || + vsn.BuildNumber < 17763 { + // If the CreatePseudoConsole API is not available, we fall back to a simpler + // implementation that doesn't create an actual PTY - just uses os.Pipe + return pipePty() + } + + return conpty.New(80, 80) +} + +func pipePty() (Pty, error) { + inFilePipeSide, inFileOurSide, err := os.Pipe() + if err != nil { + return nil, err + } + + outFileOurSide, outFilePipeSide, err := os.Pipe() + if err != nil { + return nil, err + } + + return &pipePtyVal{ + inFilePipeSide, + inFileOurSide, + outFileOurSide, + outFilePipeSide, + }, nil +} + +type pipePtyVal struct { + inFilePipeSide, inFileOurSide *os.File + outFileOurSide, outFilePipeSide *os.File +} + +func (p *pipePtyVal) InPipe() *os.File { + return p.inFilePipeSide +} + +func (p *pipePtyVal) OutPipe() *os.File { + return p.outFilePipeSide +} + +func (p *pipePtyVal) Reader() io.Reader { + return p.outFileOurSide +} + +func (p *pipePtyVal) WriteString(str string) (int, error) { + return p.inFileOurSide.WriteString(str) +} + +func (p *pipePtyVal) Resize(uint16, uint16) error { + return nil +} + +func (p *pipePtyVal) Close() error { + p.inFileOurSide.Close() + p.inFilePipeSide.Close() + p.outFilePipeSide.Close() + p.outFileOurSide.Close() + return nil +} diff --git a/expect/test_console.go b/expect/test_console.go new file mode 100644 index 0000000000000..e7d8c2a87a743 --- /dev/null +++ b/expect/test_console.go @@ -0,0 +1,45 @@ +package expect + +import ( + "bufio" + "io" + "regexp" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" +) + +var ( + // Used to ensure terminal output doesn't have anything crazy! + // See: https://stackoverflow.com/a/29497680 + stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") +) + +// NewTestConsole creates a new TTY bound to the command provided. +// All ANSI escape codes are stripped to provide clean output. +func NewTestConsole(t *testing.T, cmd *cobra.Command) *Console { + reader, writer := io.Pipe() + scanner := bufio.NewScanner(reader) + t.Cleanup(func() { + _ = reader.Close() + _ = writer.Close() + }) + go func() { + for scanner.Scan() { + if scanner.Err() != nil { + return + } + t.Log(stripAnsi.ReplaceAllString(scanner.Text(), "")) + } + }() + + console, err := NewConsole(WithStdout(writer)) + require.NoError(t, err) + t.Cleanup(func() { + console.Close() + }) + cmd.SetIn(console.InTty()) + cmd.SetOut(console.OutTty()) + return console +}