Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 145d101

Browse files
authored
test: Refactor ptytest to use contexts and less duplication (#5740)
1 parent 77e71f3 commit 145d101

File tree

4 files changed

+132
-132
lines changed

4 files changed

+132
-132
lines changed

cli/server_test.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,15 @@ func TestServer(t *testing.T) {
120120
})
121121
t.Run("BuiltinPostgresURLRaw", func(t *testing.T) {
122122
t.Parallel()
123+
ctx, _ := testutil.Context(t)
124+
123125
root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url")
124126
pty := ptytest.New(t)
125127
root.SetOutput(pty.Output())
126-
err := root.Execute()
128+
err := root.ExecuteContext(ctx)
127129
require.NoError(t, err)
128130

129-
got := pty.ReadLine()
131+
got := pty.ReadLine(ctx)
130132
if !strings.HasPrefix(got, "postgres://") {
131133
t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got)
132134
}
@@ -491,12 +493,12 @@ func TestServer(t *testing.T) {
491493
// We can't use waitAccessURL as it will only return the HTTP URL.
492494
const httpLinePrefix = "Started HTTP listener at "
493495
pty.ExpectMatch(httpLinePrefix)
494-
httpLine := pty.ReadLine()
496+
httpLine := pty.ReadLine(ctx)
495497
httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
496498
require.NotEmpty(t, httpAddr)
497499
const tlsLinePrefix = "Started TLS/HTTPS listener at "
498500
pty.ExpectMatch(tlsLinePrefix)
499-
tlsLine := pty.ReadLine()
501+
tlsLine := pty.ReadLine(ctx)
500502
tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
501503
require.NotEmpty(t, tlsAddr)
502504

@@ -617,14 +619,14 @@ func TestServer(t *testing.T) {
617619
if c.httpListener {
618620
const httpLinePrefix = "Started HTTP listener at "
619621
pty.ExpectMatch(httpLinePrefix)
620-
httpLine := pty.ReadLine()
622+
httpLine := pty.ReadLine(ctx)
621623
httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
622624
require.NotEmpty(t, httpAddr)
623625
}
624626
if c.tlsListener {
625627
const tlsLinePrefix = "Started TLS/HTTPS listener at "
626628
pty.ExpectMatch(tlsLinePrefix)
627-
tlsLine := pty.ReadLine()
629+
tlsLine := pty.ReadLine(ctx)
628630
tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
629631
require.NotEmpty(t, tlsAddr)
630632
}
@@ -1212,7 +1214,7 @@ func TestServer(t *testing.T) {
12121214

12131215
t.Run("Stackdriver", func(t *testing.T) {
12141216
t.Parallel()
1215-
ctx, cancelFunc := context.WithCancel(context.Background())
1217+
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
12161218
defer cancelFunc()
12171219

12181220
fi := testutil.TempFile(t, "", "coder-logging-test-*")
@@ -1240,10 +1242,9 @@ func TestServer(t *testing.T) {
12401242
<-serverErr
12411243
}()
12421244

1243-
require.Eventually(t, func() bool {
1244-
line := pty.ReadLine()
1245-
return strings.HasPrefix(line, "Started HTTP listener at ")
1246-
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
1245+
// Wait for server to listen on HTTP, this is a good
1246+
// starting point for expecting logs.
1247+
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
12471248

12481249
require.Eventually(t, func() bool {
12491250
stat, err := os.Stat(fi)
@@ -1253,7 +1254,7 @@ func TestServer(t *testing.T) {
12531254

12541255
t.Run("Multiple", func(t *testing.T) {
12551256
t.Parallel()
1256-
ctx, cancelFunc := context.WithCancel(context.Background())
1257+
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
12571258
defer cancelFunc()
12581259

12591260
fi1 := testutil.TempFile(t, "", "coder-logging-test-*")
@@ -1289,10 +1290,9 @@ func TestServer(t *testing.T) {
12891290
<-serverErr
12901291
}()
12911292

1292-
require.Eventually(t, func() bool {
1293-
line := pty.ReadLine()
1294-
return strings.HasPrefix(line, "Started HTTP listener at ")
1295-
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
1293+
// Wait for server to listen on HTTP, this is a good
1294+
// starting point for expecting logs.
1295+
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")
12961296

12971297
require.Eventually(t, func() bool {
12981298
stat, err := os.Stat(fi1)

cli/ssh_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ Expire-Date: 0
477477

478478
// Wait for the prompt or any output really to indicate the command has
479479
// started and accepting input on stdin.
480-
_ = pty.ReadRune(ctx)
480+
_ = pty.Peek(ctx, 1)
481481

482482
pty.WriteLine("echo hello 'world'")
483483
pty.ExpectMatch("hello world")

pty/ptytest/ptytest.go

Lines changed: 109 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"unicode/utf8"
1616

1717
"github.com/stretchr/testify/require"
18+
"golang.org/x/exp/slices"
1819
"golang.org/x/xerrors"
1920

2021
"github.com/coder/coder/pty"
@@ -143,151 +144,148 @@ func (p *PTY) ExpectMatch(str string) string {
143144
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
144145
defer cancel()
145146

147+
return p.ExpectMatchContext(timeout, str)
148+
}
149+
150+
// TODO(mafredri): Rename this to ExpectMatch when refactoring.
151+
func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string {
152+
p.t.Helper()
153+
146154
var buffer bytes.Buffer
147-
match := make(chan error, 1)
148-
go func() {
149-
defer close(match)
150-
match <- func() error {
151-
for {
152-
r, _, err := p.runeReader.ReadRune()
153-
if err != nil {
154-
return err
155-
}
156-
_, err = buffer.WriteRune(r)
157-
if err != nil {
158-
return err
159-
}
160-
if strings.Contains(buffer.String(), str) {
161-
return nil
162-
}
155+
err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error {
156+
for {
157+
r, _, err := p.runeReader.ReadRune()
158+
if err != nil {
159+
return err
160+
}
161+
_, err = buffer.WriteRune(r)
162+
if err != nil {
163+
return err
164+
}
165+
if strings.Contains(buffer.String(), str) {
166+
return nil
163167
}
164-
}()
165-
}()
166-
167-
select {
168-
case err := <-match:
169-
if err != nil {
170-
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
171-
return ""
172168
}
173-
p.logf("matched %q = %q", str, buffer.String())
174-
return buffer.String()
175-
case <-timeout.Done():
176-
// Ensure goroutine is cleaned up before test exit.
177-
_ = p.close("expect match timeout")
178-
<-match
179-
180-
p.fatalf("match exceeded deadline", "wanted %q; got %q", str, buffer.String())
169+
})
170+
if err != nil {
171+
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
181172
return ""
182173
}
174+
p.logf("matched %q = %q", str, buffer.String())
175+
return buffer.String()
183176
}
184177

185-
func (p *PTY) ReadRune(ctx context.Context) rune {
178+
func (p *PTY) Peek(ctx context.Context, n int) []byte {
186179
p.t.Helper()
187180

188-
// A timeout is mandatory, caller can decide by passing a context
189-
// that times out.
190-
if _, ok := ctx.Deadline(); !ok {
191-
timeout := testutil.WaitMedium
192-
p.logf("ReadRune ctx has no deadline, using %s", timeout)
193-
var cancel context.CancelFunc
194-
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
195-
ctx, cancel = context.WithTimeout(ctx, timeout)
196-
defer cancel()
181+
var out []byte
182+
err := p.doMatchWithDeadline(ctx, "Peek", func() error {
183+
var err error
184+
out, err = p.runeReader.Peek(n)
185+
return err
186+
})
187+
if err != nil {
188+
p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
189+
return nil
197190
}
191+
p.logf("peeked %d/%d bytes = %q", len(out), n, out)
192+
return slices.Clone(out)
193+
}
194+
195+
func (p *PTY) ReadRune(ctx context.Context) rune {
196+
p.t.Helper()
198197

199198
var r rune
200-
match := make(chan error, 1)
201-
go func() {
202-
defer close(match)
199+
err := p.doMatchWithDeadline(ctx, "ReadRune", func() error {
203200
var err error
204201
r, _, err = p.runeReader.ReadRune()
205-
match <- err
206-
}()
207-
208-
select {
209-
case err := <-match:
210-
if err != nil {
211-
p.fatalf("read error", "%v (wanted newline; got %q)", err, r)
212-
return 0
213-
}
214-
p.logf("matched rune = %q", r)
215-
return r
216-
case <-ctx.Done():
217-
// Ensure goroutine is cleaned up before test exit.
218-
_ = p.close("read rune context done: " + ctx.Err().Error())
219-
<-match
220-
221-
p.fatalf("read rune context done", "wanted rune; got nothing")
202+
return err
203+
})
204+
if err != nil {
205+
p.fatalf("read error", "%v (wanted rune; got %q)", err, r)
222206
return 0
223207
}
208+
p.logf("matched rune = %q", r)
209+
return r
224210
}
225211

226-
func (p *PTY) ReadLine() string {
212+
func (p *PTY) ReadLine(ctx context.Context) string {
227213
p.t.Helper()
228214

229-
// timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
230-
timeout, cancel := context.WithCancel(context.Background())
231-
defer cancel()
232-
233215
var buffer bytes.Buffer
234-
match := make(chan error, 1)
235-
go func() {
236-
defer close(match)
237-
match <- func() error {
238-
for {
239-
r, _, err := p.runeReader.ReadRune()
240-
if err != nil {
241-
return err
242-
}
243-
if r == '\n' {
216+
err := p.doMatchWithDeadline(ctx, "ReadLine", func() error {
217+
for {
218+
r, _, err := p.runeReader.ReadRune()
219+
if err != nil {
220+
return err
221+
}
222+
if r == '\n' {
223+
return nil
224+
}
225+
if r == '\r' {
226+
// Peek the next rune to see if it's an LF and then consume
227+
// it.
228+
229+
// Unicode code points can be up to 4 bytes, but the
230+
// ones we're looking for are only 1 byte.
231+
b, _ := p.runeReader.Peek(1)
232+
if len(b) == 0 {
244233
return nil
245234
}
246-
if r == '\r' {
247-
// Peek the next rune to see if it's an LF and then consume
248-
// it.
249-
250-
// Unicode code points can be up to 4 bytes, but the
251-
// ones we're looking for are only 1 byte.
252-
b, _ := p.runeReader.Peek(1)
253-
if len(b) == 0 {
254-
return nil
255-
}
256235

257-
r, _ = utf8.DecodeRune(b)
258-
if r == '\n' {
259-
_, _, err = p.runeReader.ReadRune()
260-
if err != nil {
261-
return err
262-
}
236+
r, _ = utf8.DecodeRune(b)
237+
if r == '\n' {
238+
_, _, err = p.runeReader.ReadRune()
239+
if err != nil {
240+
return err
263241
}
264-
265-
return nil
266242
}
267243

268-
_, err = buffer.WriteRune(r)
269-
if err != nil {
270-
return err
271-
}
244+
return nil
272245
}
273-
}()
274-
}()
275246

247+
_, err = buffer.WriteRune(r)
248+
if err != nil {
249+
return err
250+
}
251+
}
252+
})
253+
if err != nil {
254+
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
255+
return ""
256+
}
257+
p.logf("matched newline = %q", buffer.String())
258+
return buffer.String()
259+
}
260+
261+
func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error {
262+
p.t.Helper()
263+
264+
// A timeout is mandatory, caller can decide by passing a context
265+
// that times out.
266+
if _, ok := ctx.Deadline(); !ok {
267+
timeout := testutil.WaitMedium
268+
p.logf("%s ctx has no deadline, using %s", name, timeout)
269+
var cancel context.CancelFunc
270+
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
271+
ctx, cancel = context.WithTimeout(ctx, timeout)
272+
defer cancel()
273+
}
274+
275+
match := make(chan error, 1)
276+
go func() {
277+
defer close(match)
278+
match <- fn()
279+
}()
276280
select {
277281
case err := <-match:
278-
if err != nil {
279-
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
280-
return ""
281-
}
282-
p.logf("matched newline = %q", buffer.String())
283-
return buffer.String()
284-
case <-timeout.Done():
282+
return err
283+
case <-ctx.Done():
285284
// Ensure goroutine is cleaned up before test exit.
286-
_ = p.close("expect match timeout")
285+
_ = p.close("match deadline exceeded")
287286
<-match
288287

289-
p.fatalf("match exceeded deadline", "wanted newline; got %q", buffer.String())
290-
return ""
288+
return xerrors.Errorf("match deadline exceeded: %w", ctx.Err())
291289
}
292290
}
293291

0 commit comments

Comments
 (0)