@@ -15,6 +15,7 @@ import (
1515 "unicode/utf8"
1616
1717 "github.com/stretchr/testify/require"
18+ "golang.org/x/exp/slices"
1819 "golang.org/x/xerrors"
1920
2021 "github.com/coder/coder/pty"
@@ -143,151 +144,148 @@ func (p *PTY) ExpectMatch(str string) string {
143144 timeout , cancel := context .WithTimeout (context .Background (), testutil .WaitMedium )
144145 defer cancel ()
145146
147+ return p .ExpectMatchContext (timeout , str )
148+ }
149+
150+ // TODO(mafredri): Rename this to ExpectMatch when refactoring.
151+ func (p * PTY ) ExpectMatchContext (ctx context.Context , str string ) string {
152+ p .t .Helper ()
153+
146154 var buffer bytes.Buffer
147- match := make (chan error , 1 )
148- go func () {
149- defer close (match )
150- match <- func () error {
151- for {
152- r , _ , err := p .runeReader .ReadRune ()
153- if err != nil {
154- return err
155- }
156- _ , err = buffer .WriteRune (r )
157- if err != nil {
158- return err
159- }
160- if strings .Contains (buffer .String (), str ) {
161- return nil
162- }
155+ err := p .doMatchWithDeadline (ctx , "ExpectMatchContext" , func () error {
156+ for {
157+ r , _ , err := p .runeReader .ReadRune ()
158+ if err != nil {
159+ return err
160+ }
161+ _ , err = buffer .WriteRune (r )
162+ if err != nil {
163+ return err
164+ }
165+ if strings .Contains (buffer .String (), str ) {
166+ return nil
163167 }
164- }()
165- }()
166-
167- select {
168- case err := <- match :
169- if err != nil {
170- p .fatalf ("read error" , "%v (wanted %q; got %q)" , err , str , buffer .String ())
171- return ""
172168 }
173- p .logf ("matched %q = %q" , str , buffer .String ())
174- return buffer .String ()
175- case <- timeout .Done ():
176- // Ensure goroutine is cleaned up before test exit.
177- _ = p .close ("expect match timeout" )
178- <- match
179-
180- p .fatalf ("match exceeded deadline" , "wanted %q; got %q" , str , buffer .String ())
169+ })
170+ if err != nil {
171+ p .fatalf ("read error" , "%v (wanted %q; got %q)" , err , str , buffer .String ())
181172 return ""
182173 }
174+ p .logf ("matched %q = %q" , str , buffer .String ())
175+ return buffer .String ()
183176}
184177
185- func (p * PTY ) ReadRune (ctx context.Context ) rune {
178+ func (p * PTY ) Peek (ctx context.Context , n int ) [] byte {
186179 p .t .Helper ()
187180
188- // A timeout is mandatory, caller can decide by passing a context
189- // that times out.
190- if _ , ok := ctx . Deadline (); ! ok {
191- timeout := testutil . WaitMedium
192- p . logf ( "ReadRune ctx has no deadline, using %s" , timeout )
193- var cancel context. CancelFunc
194- //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
195- ctx , cancel = context . WithTimeout ( ctx , timeout )
196- defer cancel ()
181+ var out [] byte
182+ err := p . doMatchWithDeadline ( ctx , "Peek" , func () error {
183+ var err error
184+ out , err = p . runeReader . Peek ( n )
185+ return err
186+ })
187+ if err != nil {
188+ p . fatalf ( "read error" , "%v (wanted %d bytes; got %d: %q)" , err , n , len ( out ), out )
189+ return nil
197190 }
191+ p .logf ("peeked %d/%d bytes = %q" , len (out ), n , out )
192+ return slices .Clone (out )
193+ }
194+
195+ func (p * PTY ) ReadRune (ctx context.Context ) rune {
196+ p .t .Helper ()
198197
199198 var r rune
200- match := make (chan error , 1 )
201- go func () {
202- defer close (match )
199+ err := p .doMatchWithDeadline (ctx , "ReadRune" , func () error {
203200 var err error
204201 r , _ , err = p .runeReader .ReadRune ()
205- match <- err
206- }()
207-
208- select {
209- case err := <- match :
210- if err != nil {
211- p .fatalf ("read error" , "%v (wanted newline; got %q)" , err , r )
212- return 0
213- }
214- p .logf ("matched rune = %q" , r )
215- return r
216- case <- ctx .Done ():
217- // Ensure goroutine is cleaned up before test exit.
218- _ = p .close ("read rune context done: " + ctx .Err ().Error ())
219- <- match
220-
221- p .fatalf ("read rune context done" , "wanted rune; got nothing" )
202+ return err
203+ })
204+ if err != nil {
205+ p .fatalf ("read error" , "%v (wanted rune; got %q)" , err , r )
222206 return 0
223207 }
208+ p .logf ("matched rune = %q" , r )
209+ return r
224210}
225211
226- func (p * PTY ) ReadLine () string {
212+ func (p * PTY ) ReadLine (ctx context. Context ) string {
227213 p .t .Helper ()
228214
229- // timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
230- timeout , cancel := context .WithCancel (context .Background ())
231- defer cancel ()
232-
233215 var buffer bytes.Buffer
234- match := make (chan error , 1 )
235- go func () {
236- defer close (match )
237- match <- func () error {
238- for {
239- r , _ , err := p .runeReader .ReadRune ()
240- if err != nil {
241- return err
242- }
243- if r == '\n' {
216+ err := p .doMatchWithDeadline (ctx , "ReadLine" , func () error {
217+ for {
218+ r , _ , err := p .runeReader .ReadRune ()
219+ if err != nil {
220+ return err
221+ }
222+ if r == '\n' {
223+ return nil
224+ }
225+ if r == '\r' {
226+ // Peek the next rune to see if it's an LF and then consume
227+ // it.
228+
229+ // Unicode code points can be up to 4 bytes, but the
230+ // ones we're looking for are only 1 byte.
231+ b , _ := p .runeReader .Peek (1 )
232+ if len (b ) == 0 {
244233 return nil
245234 }
246- if r == '\r' {
247- // Peek the next rune to see if it's an LF and then consume
248- // it.
249-
250- // Unicode code points can be up to 4 bytes, but the
251- // ones we're looking for are only 1 byte.
252- b , _ := p .runeReader .Peek (1 )
253- if len (b ) == 0 {
254- return nil
255- }
256235
257- r , _ = utf8 .DecodeRune (b )
258- if r == '\n' {
259- _ , _ , err = p .runeReader .ReadRune ()
260- if err != nil {
261- return err
262- }
236+ r , _ = utf8 .DecodeRune (b )
237+ if r == '\n' {
238+ _ , _ , err = p .runeReader .ReadRune ()
239+ if err != nil {
240+ return err
263241 }
264-
265- return nil
266242 }
267243
268- _ , err = buffer .WriteRune (r )
269- if err != nil {
270- return err
271- }
244+ return nil
272245 }
273- }()
274- }()
275246
247+ _ , err = buffer .WriteRune (r )
248+ if err != nil {
249+ return err
250+ }
251+ }
252+ })
253+ if err != nil {
254+ p .fatalf ("read error" , "%v (wanted newline; got %q)" , err , buffer .String ())
255+ return ""
256+ }
257+ p .logf ("matched newline = %q" , buffer .String ())
258+ return buffer .String ()
259+ }
260+
261+ func (p * PTY ) doMatchWithDeadline (ctx context.Context , name string , fn func () error ) error {
262+ p .t .Helper ()
263+
264+ // A timeout is mandatory, caller can decide by passing a context
265+ // that times out.
266+ if _ , ok := ctx .Deadline (); ! ok {
267+ timeout := testutil .WaitMedium
268+ p .logf ("%s ctx has no deadline, using %s" , name , timeout )
269+ var cancel context.CancelFunc
270+ //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
271+ ctx , cancel = context .WithTimeout (ctx , timeout )
272+ defer cancel ()
273+ }
274+
275+ match := make (chan error , 1 )
276+ go func () {
277+ defer close (match )
278+ match <- fn ()
279+ }()
276280 select {
277281 case err := <- match :
278- if err != nil {
279- p .fatalf ("read error" , "%v (wanted newline; got %q)" , err , buffer .String ())
280- return ""
281- }
282- p .logf ("matched newline = %q" , buffer .String ())
283- return buffer .String ()
284- case <- timeout .Done ():
282+ return err
283+ case <- ctx .Done ():
285284 // Ensure goroutine is cleaned up before test exit.
286- _ = p .close ("expect match timeout " )
285+ _ = p .close ("match deadline exceeded " )
287286 <- match
288287
289- p .fatalf ("match exceeded deadline" , "wanted newline; got %q" , buffer .String ())
290- return ""
288+ return xerrors .Errorf ("match deadline exceeded: %w" , ctx .Err ())
291289 }
292290}
293291
0 commit comments