Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 71dc91e

Browse files
authored
fix: fix loss of buffered input on cliui.Prompt (#15421)
fixes coder/internal#203
1 parent 0987281 commit 71dc91e

File tree

2 files changed

+84
-21
lines changed

2 files changed

+84
-21
lines changed

cli/cliui/prompt.go

+31-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package cliui
22

33
import (
4-
"bufio"
54
"bytes"
65
"encoding/json"
76
"fmt"
7+
"io"
88
"os"
99
"os/signal"
1010
"strings"
@@ -96,14 +96,13 @@ func Prompt(inv *serpent.Invocation, opts PromptOptions) (string, error) {
9696
signal.Notify(interrupt, os.Interrupt)
9797
defer signal.Stop(interrupt)
9898

99-
reader := bufio.NewReader(inv.Stdin)
100-
line, err = reader.ReadString('\n')
99+
line, err = readUntil(inv.Stdin, '\n')
101100

102101
// Check if the first line beings with JSON object or array chars.
103102
// This enables multiline JSON to be pasted into an input, and have
104103
// it parse properly.
105104
if err == nil && (strings.HasPrefix(line, "{") || strings.HasPrefix(line, "[")) {
106-
line, err = promptJSON(reader, line)
105+
line, err = promptJSON(inv.Stdin, line)
107106
}
108107
}
109108
if err != nil {
@@ -144,7 +143,7 @@ func Prompt(inv *serpent.Invocation, opts PromptOptions) (string, error) {
144143
}
145144
}
146145

147-
func promptJSON(reader *bufio.Reader, line string) (string, error) {
146+
func promptJSON(reader io.Reader, line string) (string, error) {
148147
var data bytes.Buffer
149148
for {
150149
_, _ = data.WriteString(line)
@@ -162,7 +161,7 @@ func promptJSON(reader *bufio.Reader, line string) (string, error) {
162161
// Read line-by-line. We can't use a JSON decoder
163162
// here because it doesn't work by newline, so
164163
// reads will block.
165-
line, err = reader.ReadString('\n')
164+
line, err = readUntil(reader, '\n')
166165
if err != nil {
167166
break
168167
}
@@ -179,3 +178,29 @@ func promptJSON(reader *bufio.Reader, line string) (string, error) {
179178
}
180179
return line, nil
181180
}
181+
182+
// readUntil the first occurrence of delim in the input, returning a string containing the data up
183+
// to and including the delimiter. Unlike `bufio`, it only reads until the delimiter and no further
184+
// bytes. If readUntil encounters an error before finding a delimiter, it returns the data read
185+
// before the error and the error itself (often io.EOF). readUntil returns err != nil if and only if
186+
// the returned data does not end in delim.
187+
func readUntil(r io.Reader, delim byte) (string, error) {
188+
var (
189+
have []byte
190+
b = make([]byte, 1)
191+
)
192+
for {
193+
n, err := r.Read(b)
194+
if n > 0 {
195+
have = append(have, b[0])
196+
if b[0] == delim {
197+
// match `bufio` in that we only return non-nil if we didn't find the delimiter,
198+
// regardless of whether we also erred.
199+
return string(have), nil
200+
}
201+
}
202+
if err != nil {
203+
return string(have), err
204+
}
205+
}
206+
}

cli/cliui/prompt_test.go

+53-15
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/stretchr/testify/assert"
1212
"github.com/stretchr/testify/require"
13+
"golang.org/x/xerrors"
1314

1415
"github.com/coder/coder/v2/cli/cliui"
1516
"github.com/coder/coder/v2/pty"
@@ -22,26 +23,29 @@ func TestPrompt(t *testing.T) {
2223
t.Parallel()
2324
t.Run("Success", func(t *testing.T) {
2425
t.Parallel()
26+
ctx := testutil.Context(t, testutil.WaitShort)
2527
ptty := ptytest.New(t)
2628
msgChan := make(chan string)
2729
go func() {
28-
resp, err := newPrompt(ptty, cliui.PromptOptions{
30+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
2931
Text: "Example",
3032
}, nil)
3133
assert.NoError(t, err)
3234
msgChan <- resp
3335
}()
3436
ptty.ExpectMatch("Example")
3537
ptty.WriteLine("hello")
36-
require.Equal(t, "hello", <-msgChan)
38+
resp := testutil.RequireRecvCtx(ctx, t, msgChan)
39+
require.Equal(t, "hello", resp)
3740
})
3841

3942
t.Run("Confirm", func(t *testing.T) {
4043
t.Parallel()
44+
ctx := testutil.Context(t, testutil.WaitShort)
4145
ptty := ptytest.New(t)
4246
doneChan := make(chan string)
4347
go func() {
44-
resp, err := newPrompt(ptty, cliui.PromptOptions{
48+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
4549
Text: "Example",
4650
IsConfirm: true,
4751
}, nil)
@@ -50,18 +54,20 @@ func TestPrompt(t *testing.T) {
5054
}()
5155
ptty.ExpectMatch("Example")
5256
ptty.WriteLine("yes")
53-
require.Equal(t, "yes", <-doneChan)
57+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
58+
require.Equal(t, "yes", resp)
5459
})
5560

5661
t.Run("Skip", func(t *testing.T) {
5762
t.Parallel()
63+
ctx := testutil.Context(t, testutil.WaitShort)
5864
ptty := ptytest.New(t)
5965
var buf bytes.Buffer
6066

6167
// Copy all data written out to a buffer. When we close the ptty, we can
6268
// no longer read from the ptty.Output(), but we can read what was
6369
// written to the buffer.
64-
dataRead, doneReading := context.WithTimeout(context.Background(), testutil.WaitShort)
70+
dataRead, doneReading := context.WithCancel(ctx)
6571
go func() {
6672
// This will throw an error sometimes. The underlying ptty
6773
// has its own cleanup routines in t.Cleanup. Instead of
@@ -74,7 +80,7 @@ func TestPrompt(t *testing.T) {
7480

7581
doneChan := make(chan string)
7682
go func() {
77-
resp, err := newPrompt(ptty, cliui.PromptOptions{
83+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
7884
Text: "ShouldNotSeeThis",
7985
IsConfirm: true,
8086
}, func(inv *serpent.Invocation) {
@@ -85,7 +91,8 @@ func TestPrompt(t *testing.T) {
8591
doneChan <- resp
8692
}()
8793

88-
require.Equal(t, "yes", <-doneChan)
94+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
95+
require.Equal(t, "yes", resp)
8996
// Close the reader to end the io.Copy
9097
require.NoError(t, ptty.Close(), "close eof reader")
9198
// Wait for the IO copy to finish
@@ -96,42 +103,47 @@ func TestPrompt(t *testing.T) {
96103
})
97104
t.Run("JSON", func(t *testing.T) {
98105
t.Parallel()
106+
ctx := testutil.Context(t, testutil.WaitShort)
99107
ptty := ptytest.New(t)
100108
doneChan := make(chan string)
101109
go func() {
102-
resp, err := newPrompt(ptty, cliui.PromptOptions{
110+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
103111
Text: "Example",
104112
}, nil)
105113
assert.NoError(t, err)
106114
doneChan <- resp
107115
}()
108116
ptty.ExpectMatch("Example")
109117
ptty.WriteLine("{}")
110-
require.Equal(t, "{}", <-doneChan)
118+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
119+
require.Equal(t, "{}", resp)
111120
})
112121

113122
t.Run("BadJSON", func(t *testing.T) {
114123
t.Parallel()
124+
ctx := testutil.Context(t, testutil.WaitShort)
115125
ptty := ptytest.New(t)
116126
doneChan := make(chan string)
117127
go func() {
118-
resp, err := newPrompt(ptty, cliui.PromptOptions{
128+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
119129
Text: "Example",
120130
}, nil)
121131
assert.NoError(t, err)
122132
doneChan <- resp
123133
}()
124134
ptty.ExpectMatch("Example")
125135
ptty.WriteLine("{a")
126-
require.Equal(t, "{a", <-doneChan)
136+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
137+
require.Equal(t, "{a", resp)
127138
})
128139

129140
t.Run("MultilineJSON", func(t *testing.T) {
130141
t.Parallel()
142+
ctx := testutil.Context(t, testutil.WaitShort)
131143
ptty := ptytest.New(t)
132144
doneChan := make(chan string)
133145
go func() {
134-
resp, err := newPrompt(ptty, cliui.PromptOptions{
146+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
135147
Text: "Example",
136148
}, nil)
137149
assert.NoError(t, err)
@@ -141,11 +153,37 @@ func TestPrompt(t *testing.T) {
141153
ptty.WriteLine(`{
142154
"test": "wow"
143155
}`)
144-
require.Equal(t, `{"test":"wow"}`, <-doneChan)
156+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
157+
require.Equal(t, `{"test":"wow"}`, resp)
158+
})
159+
160+
t.Run("InvalidValid", func(t *testing.T) {
161+
t.Parallel()
162+
ctx := testutil.Context(t, testutil.WaitShort)
163+
ptty := ptytest.New(t)
164+
doneChan := make(chan string)
165+
go func() {
166+
resp, err := newPrompt(ctx, ptty, cliui.PromptOptions{
167+
Text: "Example",
168+
Validate: func(s string) error {
169+
t.Logf("validate: %q", s)
170+
if s != "valid" {
171+
return xerrors.New("invalid")
172+
}
173+
return nil
174+
},
175+
}, nil)
176+
assert.NoError(t, err)
177+
doneChan <- resp
178+
}()
179+
ptty.ExpectMatch("Example")
180+
ptty.WriteLine("foo\nbar\nbaz\n\n\nvalid\n")
181+
resp := testutil.RequireRecvCtx(ctx, t, doneChan)
182+
require.Equal(t, "valid", resp)
145183
})
146184
}
147185

148-
func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *serpent.Invocation)) (string, error) {
186+
func newPrompt(ctx context.Context, ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *serpent.Invocation)) (string, error) {
149187
value := ""
150188
cmd := &serpent.Command{
151189
Handler: func(inv *serpent.Invocation) error {
@@ -163,7 +201,7 @@ func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *ser
163201
inv.Stdout = ptty.Output()
164202
inv.Stderr = ptty.Output()
165203
inv.Stdin = ptty.Input()
166-
return value, inv.WithContext(context.Background()).Run()
204+
return value, inv.WithContext(ctx).Run()
167205
}
168206

169207
func TestPasswordTerminalState(t *testing.T) {

0 commit comments

Comments
 (0)