Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ffda8cd

Browse files
committed
fix: Rewrite ptytest to buffer stdout
Fixes #2122
1 parent 54547a4 commit ffda8cd

File tree

3 files changed

+206
-47
lines changed

3 files changed

+206
-47
lines changed

pty/ptytest/ptytest.go

Lines changed: 170 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,116 +7,247 @@ import (
77
"io"
88
"os"
99
"os/exec"
10-
"regexp"
1110
"runtime"
1211
"strings"
12+
"sync"
1313
"testing"
1414
"time"
1515
"unicode/utf8"
1616

1717
"github.com/stretchr/testify/require"
18+
"golang.org/x/xerrors"
1819

1920
"github.com/coder/coder/pty"
2021
)
2122

22-
var (
23-
// Used to ensure terminal output doesn't have anything crazy!
24-
// See: https://stackoverflow.com/a/29497680
25-
stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))")
26-
)
27-
2823
func New(t *testing.T) *PTY {
2924
ptty, err := pty.New()
3025
require.NoError(t, err)
3126

32-
return create(t, ptty)
27+
return create(t, ptty, "cmd")
3328
}
3429

3530
func Start(t *testing.T, cmd *exec.Cmd) (*PTY, *os.Process) {
3631
ptty, ps, err := pty.Start(cmd)
3732
require.NoError(t, err)
38-
return create(t, ptty), ps
33+
return create(t, ptty, cmd.Args[0]), ps
3934
}
4035

41-
func create(t *testing.T, ptty pty.PTY) *PTY {
42-
reader, writer := io.Pipe()
43-
scanner := bufio.NewScanner(reader)
36+
func create(t *testing.T, ptty pty.PTY, name string) *PTY {
37+
// Use pipe for logging.
38+
logDone := make(chan struct{})
39+
logr, logw := io.Pipe()
4440
t.Cleanup(func() {
45-
_ = reader.Close()
46-
_ = writer.Close()
41+
_ = logw.Close()
42+
_ = logr.Close()
43+
<-logDone // Guard against logging after test.
4744
})
4845
go func() {
49-
for scanner.Scan() {
50-
if scanner.Err() != nil {
51-
return
52-
}
53-
t.Log(stripAnsi.ReplaceAllString(scanner.Text(), ""))
46+
defer close(logDone)
47+
s := bufio.NewScanner(logr)
48+
for s.Scan() {
49+
// Quote output to avoid terminal escape codes, e.g. bell.
50+
t.Logf("%s: stdout: %q", name, s.Text())
5451
}
5552
}()
5653

54+
// Write to log and output buffer.
55+
copyDone := make(chan struct{})
56+
out := newStdbuf()
57+
w := io.MultiWriter(logw, out)
58+
go func() {
59+
defer close(copyDone)
60+
_, err := io.Copy(w, ptty.Output())
61+
_ = out.closeErr(err)
62+
}()
5763
t.Cleanup(func() {
64+
_ = out.Close
5865
_ = ptty.Close()
66+
<-copyDone
5967
})
68+
6069
return &PTY{
6170
t: t,
6271
PTY: ptty,
72+
out: out,
6373

64-
outputWriter: writer,
65-
runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax),
74+
runeReader: bufio.NewReaderSize(out, utf8.UTFMax),
6675
}
6776
}
6877

6978
type PTY struct {
7079
t *testing.T
7180
pty.PTY
81+
out *stdbuf
7282

73-
outputWriter io.Writer
74-
runeReader *bufio.Reader
83+
runeReader *bufio.Reader
7584
}
7685

7786
func (p *PTY) ExpectMatch(str string) string {
78-
var buffer bytes.Buffer
79-
multiWriter := io.MultiWriter(&buffer, p.outputWriter)
80-
runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax)
87+
p.t.Helper()
88+
8189
complete, cancelFunc := context.WithCancel(context.Background())
8290
defer cancelFunc()
91+
92+
timeout := make(chan error, 1)
8393
go func() {
94+
defer close(timeout)
8495
timer := time.NewTimer(10 * time.Second)
8596
defer timer.Stop()
8697
select {
8798
case <-complete.Done():
8899
return
89100
case <-timer.C:
90101
}
91-
_ = p.Close()
92-
p.t.Errorf("%s match exceeded deadline: wanted %q; got %q", time.Now(), str, buffer.String())
102+
timeout <- xerrors.Errorf("%s match exceeded deadline", time.Now())
93103
}()
94-
for {
95-
var r rune
96-
r, _, err := p.runeReader.ReadRune()
97-
require.NoError(p.t, err)
98-
_, err = runeWriter.WriteRune(r)
99-
require.NoError(p.t, err)
100-
err = runeWriter.Flush()
101-
require.NoError(p.t, err)
102-
if strings.Contains(buffer.String(), str) {
103-
break
104+
105+
var buffer bytes.Buffer
106+
match := make(chan error, 1)
107+
go func() {
108+
defer close(match)
109+
for {
110+
r, _, err := p.runeReader.ReadRune()
111+
if err != nil {
112+
match <- err
113+
return
114+
}
115+
_, err = buffer.WriteRune(r)
116+
if err != nil {
117+
match <- err
118+
return
119+
}
120+
if strings.Contains(buffer.String(), str) {
121+
match <- nil
122+
return
123+
}
124+
}
125+
}()
126+
127+
select {
128+
case err := <-match:
129+
if err != nil {
130+
p.t.Fatalf("read error: %v (wanted %q; got %q)", err, str, buffer.String())
104131
}
132+
p.t.Logf("matched %q = %q", str, buffer.String())
133+
case err := <-timeout:
134+
_ = p.out.closeErr(p.Close())
135+
p.t.Fatalf("%s: wanted %q; got %q", err, str, buffer.String())
105136
}
106-
p.t.Logf("matched %q = %q", str, stripAnsi.ReplaceAllString(buffer.String(), ""))
107137
return buffer.String()
108138
}
109139

110140
func (p *PTY) Write(r rune) {
141+
p.t.Helper()
142+
111143
_, err := p.Input().Write([]byte{byte(r)})
112144
require.NoError(p.t, err)
113145
}
114146

115147
func (p *PTY) WriteLine(str string) {
148+
p.t.Helper()
149+
116150
newline := []byte{'\r'}
117151
if runtime.GOOS == "windows" {
118152
newline = append(newline, '\n')
119153
}
120154
_, err := p.Input().Write(append([]byte(str), newline...))
121155
require.NoError(p.t, err)
122156
}
157+
158+
// stdbuf is like a buffered stdout, it buffers writes until read.
159+
type stdbuf struct {
160+
r io.Reader
161+
162+
mu sync.Mutex // Protects following.
163+
b []byte
164+
more chan struct{}
165+
err error
166+
}
167+
168+
func newStdbuf() *stdbuf {
169+
return &stdbuf{more: make(chan struct{}, 1)}
170+
}
171+
172+
func (b *stdbuf) Read(p []byte) (int, error) {
173+
if b.r == nil {
174+
return b.read(p)
175+
}
176+
177+
n, err := b.r.Read(p)
178+
if xerrors.Is(err, io.EOF) {
179+
b.r = nil
180+
err = nil
181+
if n == 0 {
182+
return b.read(p)
183+
}
184+
}
185+
return n, err
186+
}
187+
188+
func (b *stdbuf) read(p []byte) (int, error) {
189+
b.mu.Lock()
190+
defer b.mu.Unlock()
191+
192+
// Deplete channel so that more check
193+
// is for future input into buffer.
194+
select {
195+
case <-b.more:
196+
default:
197+
}
198+
199+
if len(b.b) == 0 {
200+
if b.err != nil {
201+
return 0, b.err
202+
}
203+
204+
b.mu.Unlock()
205+
<-b.more
206+
b.mu.Lock()
207+
}
208+
209+
b.r = bytes.NewReader(b.b)
210+
b.b = b.b[len(b.b):]
211+
212+
return b.r.Read(p)
213+
}
214+
215+
func (b *stdbuf) Write(p []byte) (int, error) {
216+
if len(p) == 0 {
217+
return 0, nil
218+
}
219+
220+
b.mu.Lock()
221+
defer b.mu.Unlock()
222+
223+
if b.err != nil {
224+
return 0, b.err
225+
}
226+
227+
b.b = append(b.b, p...)
228+
229+
select {
230+
case b.more <- struct{}{}:
231+
default:
232+
}
233+
234+
return len(p), nil
235+
}
236+
237+
func (b *stdbuf) Close() error {
238+
return b.closeErr(nil)
239+
}
240+
241+
func (b *stdbuf) closeErr(err error) error {
242+
b.mu.Lock()
243+
defer b.mu.Unlock()
244+
if b.err != nil {
245+
return err
246+
}
247+
if err == nil {
248+
err = io.EOF
249+
}
250+
b.err = err
251+
close(b.more)
252+
return b.err
253+
}

pty/ptytest/ptytest_internal_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package ptytest
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestStdbuf(t *testing.T) {
12+
t.Parallel()
13+
14+
var got bytes.Buffer
15+
16+
b := newStdbuf()
17+
done := make(chan struct{})
18+
go func() {
19+
defer close(done)
20+
io.Copy(&got, b)
21+
}()
22+
23+
b.Write([]byte("hello "))
24+
b.Write([]byte("world\n"))
25+
b.Write([]byte("bye\n"))
26+
27+
b.Close()
28+
<-done
29+
30+
assert.Equal(t, "hello world\nbye\n", got.String())
31+
}

pty/ptytest/ptytest_test.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package ptytest_test
22

33
import (
44
"fmt"
5-
"runtime"
65
"strings"
76
"testing"
87

@@ -22,26 +21,24 @@ func TestPtytest(t *testing.T) {
2221
pty.WriteLine("read")
2322
})
2423

24+
// See https://github.com/coder/coder/issues/2122 for the motivation
25+
// behind this test.
2526
t.Run("Cobra ptytest should not hang when output is not consumed", func(t *testing.T) {
2627
t.Parallel()
2728

2829
tests := []struct {
2930
name string
3031
output string
31-
isPlatformBug bool // See https://github.com/coder/coder/issues/2122 for more info.
32+
isPlatformBug bool
3233
}{
3334
{name: "1024 is safe (does not exceed macOS buffer)", output: strings.Repeat(".", 1024)},
34-
{name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025), isPlatformBug: true},
35-
{name: "10241 large output", output: strings.Repeat(".", 10241), isPlatformBug: true}, // 1024 * 10 + 1
35+
{name: "1025 exceeds macOS buffer (must not hang)", output: strings.Repeat(".", 1025)},
36+
{name: "10241 large output", output: strings.Repeat(".", 10241)}, // 1024 * 10 + 1
3637
}
3738
for _, tt := range tests {
3839
tt := tt
3940
// nolint:paralleltest // Avoid parallel test to more easily identify the issue.
4041
t.Run(tt.name, func(t *testing.T) {
41-
if tt.isPlatformBug && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") {
42-
t.Skip("This test hangs on macOS and Windows, see https://github.com/coder/coder/issues/2122")
43-
}
44-
4542
cmd := cobra.Command{
4643
Use: "test",
4744
RunE: func(cmd *cobra.Command, args []string) error {

0 commit comments

Comments
 (0)