Thanks to visit codestin.com
Credit goes to github.com

Skip to content

test: Refactor ptytest to use contexts and less duplication #5740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions cli/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,15 @@ func TestServer(t *testing.T) {
})
t.Run("BuiltinPostgresURLRaw", func(t *testing.T) {
t.Parallel()
ctx, _ := testutil.Context(t)

root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url")
pty := ptytest.New(t)
root.SetOutput(pty.Output())
err := root.Execute()
err := root.ExecuteContext(ctx)
require.NoError(t, err)

got := pty.ReadLine()
got := pty.ReadLine(ctx)
if !strings.HasPrefix(got, "postgres://") {
t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got)
}
Expand Down Expand Up @@ -491,12 +493,12 @@ func TestServer(t *testing.T) {
// We can't use waitAccessURL as it will only return the HTTP URL.
const httpLinePrefix = "Started HTTP listener at "
pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine()
httpLine := pty.ReadLine(ctx)
httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr)
const tlsLinePrefix = "Started TLS/HTTPS listener at "
pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine()
tlsLine := pty.ReadLine(ctx)
tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr)

Expand Down Expand Up @@ -617,14 +619,14 @@ func TestServer(t *testing.T) {
if c.httpListener {
const httpLinePrefix = "Started HTTP listener at "
pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine()
httpLine := pty.ReadLine(ctx)
httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr)
}
if c.tlsListener {
const tlsLinePrefix = "Started TLS/HTTPS listener at "
pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine()
tlsLine := pty.ReadLine(ctx)
tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr)
}
Expand Down Expand Up @@ -1212,7 +1214,7 @@ func TestServer(t *testing.T) {

t.Run("Stackdriver", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc()

fi := testutil.TempFile(t, "", "coder-logging-test-*")
Expand Down Expand Up @@ -1240,10 +1242,9 @@ func TestServer(t *testing.T) {
<-serverErr
}()

require.Eventually(t, func() bool {
line := pty.ReadLine()
return strings.HasPrefix(line, "Started HTTP listener at ")
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
// Wait for server to listen on HTTP, this is a good
// starting point for expecting logs.
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")

require.Eventually(t, func() bool {
stat, err := os.Stat(fi)
Expand All @@ -1253,7 +1254,7 @@ func TestServer(t *testing.T) {

t.Run("Multiple", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc()

fi1 := testutil.TempFile(t, "", "coder-logging-test-*")
Expand Down Expand Up @@ -1289,10 +1290,9 @@ func TestServer(t *testing.T) {
<-serverErr
}()

require.Eventually(t, func() bool {
line := pty.ReadLine()
return strings.HasPrefix(line, "Started HTTP listener at ")
}, testutil.WaitLong*2, testutil.IntervalMedium, "wait for server to listen on http")
// Wait for server to listen on HTTP, this is a good
// starting point for expecting logs.
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ")

require.Eventually(t, func() bool {
stat, err := os.Stat(fi1)
Expand Down
2 changes: 1 addition & 1 deletion cli/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ Expire-Date: 0

// Wait for the prompt or any output really to indicate the command has
// started and accepting input on stdin.
_ = pty.ReadRune(ctx)
_ = pty.Peek(ctx, 1)

pty.WriteLine("echo hello 'world'")
pty.ExpectMatch("hello world")
Expand Down
220 changes: 109 additions & 111 deletions pty/ptytest/ptytest.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"unicode/utf8"

"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"

"github.com/coder/coder/pty"
Expand Down Expand Up @@ -143,151 +144,148 @@ func (p *PTY) ExpectMatch(str string) string {
timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()

return p.ExpectMatchContext(timeout, str)
}

// TODO(mafredri): Rename this to ExpectMatch when refactoring.
func (p *PTY) ExpectMatchContext(ctx context.Context, str string) string {
p.t.Helper()

var buffer bytes.Buffer
match := make(chan error, 1)
go func() {
defer close(match)
match <- func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
_, err = buffer.WriteRune(r)
if err != nil {
return err
}
if strings.Contains(buffer.String(), str) {
return nil
}
err := p.doMatchWithDeadline(ctx, "ExpectMatchContext", func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
_, err = buffer.WriteRune(r)
if err != nil {
return err
}
if strings.Contains(buffer.String(), str) {
return nil
}
}()
}()

select {
case err := <-match:
if err != nil {
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
return ""
}
p.logf("matched %q = %q", str, buffer.String())
return buffer.String()
case <-timeout.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("expect match timeout")
<-match

p.fatalf("match exceeded deadline", "wanted %q; got %q", str, buffer.String())
})
if err != nil {
p.fatalf("read error", "%v (wanted %q; got %q)", err, str, buffer.String())
return ""
}
p.logf("matched %q = %q", str, buffer.String())
return buffer.String()
}

func (p *PTY) ReadRune(ctx context.Context) rune {
func (p *PTY) Peek(ctx context.Context, n int) []byte {
p.t.Helper()

// A timeout is mandatory, caller can decide by passing a context
// that times out.
if _, ok := ctx.Deadline(); !ok {
timeout := testutil.WaitMedium
p.logf("ReadRune ctx has no deadline, using %s", timeout)
var cancel context.CancelFunc
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
var out []byte
err := p.doMatchWithDeadline(ctx, "Peek", func() error {
var err error
out, err = p.runeReader.Peek(n)
return err
})
if err != nil {
p.fatalf("read error", "%v (wanted %d bytes; got %d: %q)", err, n, len(out), out)
return nil
}
p.logf("peeked %d/%d bytes = %q", len(out), n, out)
return slices.Clone(out)
}

func (p *PTY) ReadRune(ctx context.Context) rune {
p.t.Helper()

var r rune
match := make(chan error, 1)
go func() {
defer close(match)
err := p.doMatchWithDeadline(ctx, "ReadRune", func() error {
var err error
r, _, err = p.runeReader.ReadRune()
match <- err
}()

select {
case err := <-match:
if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, r)
return 0
}
p.logf("matched rune = %q", r)
return r
case <-ctx.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("read rune context done: " + ctx.Err().Error())
<-match

p.fatalf("read rune context done", "wanted rune; got nothing")
return err
})
if err != nil {
p.fatalf("read error", "%v (wanted rune; got %q)", err, r)
return 0
}
p.logf("matched rune = %q", r)
return r
}

func (p *PTY) ReadLine() string {
func (p *PTY) ReadLine(ctx context.Context) string {
p.t.Helper()

// timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
timeout, cancel := context.WithCancel(context.Background())
defer cancel()

var buffer bytes.Buffer
match := make(chan error, 1)
go func() {
defer close(match)
match <- func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
if r == '\n' {
err := p.doMatchWithDeadline(ctx, "ReadLine", func() error {
for {
r, _, err := p.runeReader.ReadRune()
if err != nil {
return err
}
if r == '\n' {
return nil
}
if r == '\r' {
// Peek the next rune to see if it's an LF and then consume
// it.

// Unicode code points can be up to 4 bytes, but the
// ones we're looking for are only 1 byte.
b, _ := p.runeReader.Peek(1)
if len(b) == 0 {
return nil
}
if r == '\r' {
// Peek the next rune to see if it's an LF and then consume
// it.

// Unicode code points can be up to 4 bytes, but the
// ones we're looking for are only 1 byte.
b, _ := p.runeReader.Peek(1)
if len(b) == 0 {
return nil
}

r, _ = utf8.DecodeRune(b)
if r == '\n' {
_, _, err = p.runeReader.ReadRune()
if err != nil {
return err
}
r, _ = utf8.DecodeRune(b)
if r == '\n' {
_, _, err = p.runeReader.ReadRune()
if err != nil {
return err
}

return nil
}

_, err = buffer.WriteRune(r)
if err != nil {
return err
}
return nil
}
}()
}()

_, err = buffer.WriteRune(r)
if err != nil {
return err
}
}
})
if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
return ""
}
p.logf("matched newline = %q", buffer.String())
return buffer.String()
}

func (p *PTY) doMatchWithDeadline(ctx context.Context, name string, fn func() error) error {
p.t.Helper()

// A timeout is mandatory, caller can decide by passing a context
// that times out.
if _, ok := ctx.Deadline(); !ok {
timeout := testutil.WaitMedium
p.logf("%s ctx has no deadline, using %s", name, timeout)
var cancel context.CancelFunc
//nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}

match := make(chan error, 1)
go func() {
defer close(match)
match <- fn()
}()
select {
case err := <-match:
if err != nil {
p.fatalf("read error", "%v (wanted newline; got %q)", err, buffer.String())
return ""
}
p.logf("matched newline = %q", buffer.String())
return buffer.String()
case <-timeout.Done():
return err
case <-ctx.Done():
// Ensure goroutine is cleaned up before test exit.
_ = p.close("expect match timeout")
_ = p.close("match deadline exceeded")
<-match

p.fatalf("match exceeded deadline", "wanted newline; got %q", buffer.String())
return ""
return xerrors.Errorf("match deadline exceeded: %w", ctx.Err())
}
}

Expand Down
Loading