From dc54a2511f5756231484fced54bea443328c7151 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 7 Nov 2024 10:15:11 +0400 Subject: [PATCH] fix: fix loss of buffered input on cliui.Prompt --- cli/cliui/prompt.go | 37 ++++++++++++++++++---- cli/cliui/prompt_test.go | 68 +++++++++++++++++++++++++++++++--------- 2 files changed, 84 insertions(+), 21 deletions(-) diff --git a/cli/cliui/prompt.go b/cli/cliui/prompt.go index 6057af69b672b..3d1ee4204fb63 100644 --- a/cli/cliui/prompt.go +++ b/cli/cliui/prompt.go @@ -1,10 +1,10 @@ package cliui import ( - "bufio" "bytes" "encoding/json" "fmt" + "io" "os" "os/signal" "strings" @@ -96,14 +96,13 @@ func Prompt(inv *serpent.Invocation, opts PromptOptions) (string, error) { signal.Notify(interrupt, os.Interrupt) defer signal.Stop(interrupt) - reader := bufio.NewReader(inv.Stdin) - line, err = reader.ReadString('\n') + line, err = readUntil(inv.Stdin, '\n') // Check if the first line beings with JSON object or array chars. // This enables multiline JSON to be pasted into an input, and have // it parse properly. if err == nil && (strings.HasPrefix(line, "{") || strings.HasPrefix(line, "[")) { - line, err = promptJSON(reader, line) + line, err = promptJSON(inv.Stdin, line) } } if err != nil { @@ -144,7 +143,7 @@ func Prompt(inv *serpent.Invocation, opts PromptOptions) (string, error) { } } -func promptJSON(reader *bufio.Reader, line string) (string, error) { +func promptJSON(reader io.Reader, line string) (string, error) { var data bytes.Buffer for { _, _ = data.WriteString(line) @@ -162,7 +161,7 @@ func promptJSON(reader *bufio.Reader, line string) (string, error) { // Read line-by-line. We can't use a JSON decoder // here because it doesn't work by newline, so // reads will block. - line, err = reader.ReadString('\n') + line, err = readUntil(reader, '\n') if err != nil { break } @@ -179,3 +178,29 @@ func promptJSON(reader *bufio.Reader, line string) (string, error) { } return line, nil } + +// readUntil the first occurrence of delim in the input, returning a string containing the data up +// to and including the delimiter. Unlike `bufio`, it only reads until the delimiter and no further +// bytes. If readUntil encounters an error before finding a delimiter, it returns the data read +// before the error and the error itself (often io.EOF). readUntil returns err != nil if and only if +// the returned data does not end in delim. +func readUntil(r io.Reader, delim byte) (string, error) { + var ( + have []byte + b = make([]byte, 1) + ) + for { + n, err := r.Read(b) + if n > 0 { + have = append(have, b[0]) + if b[0] == delim { + // match `bufio` in that we only return non-nil if we didn't find the delimiter, + // regardless of whether we also erred. + return string(have), nil + } + } + if err != nil { + return string(have), err + } + } +} diff --git a/cli/cliui/prompt_test.go b/cli/cliui/prompt_test.go index 70f5fdf48a355..58736ca8d16c8 100644 --- a/cli/cliui/prompt_test.go +++ b/cli/cliui/prompt_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/v2/cli/cliui" "github.com/coder/coder/v2/pty" @@ -22,10 +23,11 @@ func TestPrompt(t *testing.T) { t.Parallel() t.Run("Success", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) ptty := ptytest.New(t) msgChan := make(chan string) go func() { - resp, err := newPrompt(ptty, cliui.PromptOptions{ + resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{ Text: "Example", }, nil) assert.NoError(t, err) @@ -33,15 +35,17 @@ func TestPrompt(t *testing.T) { }() ptty.ExpectMatch("Example") ptty.WriteLine("hello") - require.Equal(t, "hello", <-msgChan) + resp := testutil.RequireRecvCtx(ctx, t, msgChan) + require.Equal(t, "hello", resp) }) t.Run("Confirm", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) ptty := ptytest.New(t) doneChan := make(chan string) go func() { - resp, err := newPrompt(ptty, cliui.PromptOptions{ + resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{ Text: "Example", IsConfirm: true, }, nil) @@ -50,18 +54,20 @@ func TestPrompt(t *testing.T) { }() ptty.ExpectMatch("Example") ptty.WriteLine("yes") - require.Equal(t, "yes", <-doneChan) + resp := testutil.RequireRecvCtx(ctx, t, doneChan) + require.Equal(t, "yes", resp) }) t.Run("Skip", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) ptty := ptytest.New(t) var buf bytes.Buffer // Copy all data written out to a buffer. When we close the ptty, we can // no longer read from the ptty.Output(), but we can read what was // written to the buffer. - dataRead, doneReading := context.WithTimeout(context.Background(), testutil.WaitShort) + dataRead, doneReading := context.WithCancel(ctx) go func() { // This will throw an error sometimes. The underlying ptty // has its own cleanup routines in t.Cleanup. Instead of @@ -74,7 +80,7 @@ func TestPrompt(t *testing.T) { doneChan := make(chan string) go func() { - resp, err := newPrompt(ptty, cliui.PromptOptions{ + resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{ Text: "ShouldNotSeeThis", IsConfirm: true, }, func(inv *serpent.Invocation) { @@ -85,7 +91,8 @@ func TestPrompt(t *testing.T) { doneChan <- resp }() - require.Equal(t, "yes", <-doneChan) + resp := testutil.RequireRecvCtx(ctx, t, doneChan) + require.Equal(t, "yes", resp) // Close the reader to end the io.Copy require.NoError(t, ptty.Close(), "close eof reader") // Wait for the IO copy to finish @@ -96,10 +103,11 @@ func TestPrompt(t *testing.T) { }) t.Run("JSON", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) ptty := ptytest.New(t) doneChan := make(chan string) go func() { - resp, err := newPrompt(ptty, cliui.PromptOptions{ + resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{ Text: "Example", }, nil) assert.NoError(t, err) @@ -107,15 +115,17 @@ func TestPrompt(t *testing.T) { }() ptty.ExpectMatch("Example") ptty.WriteLine("{}") - require.Equal(t, "{}", <-doneChan) + resp := testutil.RequireRecvCtx(ctx, t, doneChan) + require.Equal(t, "{}", resp) }) t.Run("BadJSON", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) ptty := ptytest.New(t) doneChan := make(chan string) go func() { - resp, err := newPrompt(ptty, cliui.PromptOptions{ + resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{ Text: "Example", }, nil) assert.NoError(t, err) @@ -123,15 +133,17 @@ func TestPrompt(t *testing.T) { }() ptty.ExpectMatch("Example") ptty.WriteLine("{a") - require.Equal(t, "{a", <-doneChan) + resp := testutil.RequireRecvCtx(ctx, t, doneChan) + require.Equal(t, "{a", resp) }) t.Run("MultilineJSON", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) ptty := ptytest.New(t) doneChan := make(chan string) go func() { - resp, err := newPrompt(ptty, cliui.PromptOptions{ + resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{ Text: "Example", }, nil) assert.NoError(t, err) @@ -141,11 +153,37 @@ func TestPrompt(t *testing.T) { ptty.WriteLine(`{ "test": "wow" }`) - require.Equal(t, `{"test":"wow"}`, <-doneChan) + resp := testutil.RequireRecvCtx(ctx, t, doneChan) + require.Equal(t, `{"test":"wow"}`, resp) + }) + + t.Run("InvalidValid", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + ptty := ptytest.New(t) + doneChan := make(chan string) + go func() { + resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{ + Text: "Example", + Validate: func(s string) error { + t.Logf("validate: %q", s) + if s != "valid" { + return xerrors.New("invalid") + } + return nil + }, + }, nil) + assert.NoError(t, err) + doneChan <- resp + }() + ptty.ExpectMatch("Example") + ptty.WriteLine("foo\nbar\nbaz\n\n\nvalid\n") + resp := testutil.RequireRecvCtx(ctx, t, doneChan) + require.Equal(t, "valid", resp) }) } -func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *serpent.Invocation)) (string, error) { +func newPrompt(ctx context.Context, ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *serpent.Invocation)) (string, error) { value := "" cmd := &serpent.Command{ Handler: func(inv *serpent.Invocation) error { @@ -163,7 +201,7 @@ func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *ser inv.Stdout = ptty.Output() inv.Stderr = ptty.Output() inv.Stdin = ptty.Input() - return value, inv.WithContext(context.Background()).Run() + return value, inv.WithContext(ctx).Run() } func TestPasswordTerminalState(t *testing.T) {