@@ -15,6 +15,7 @@ import (
15
15
"unicode/utf8"
16
16
17
17
"github.com/stretchr/testify/require"
18
+ "golang.org/x/exp/slices"
18
19
"golang.org/x/xerrors"
19
20
20
21
"github.com/coder/coder/pty"
@@ -143,151 +144,148 @@ func (p *PTY) ExpectMatch(str string) string {
143
144
timeout , cancel := context .WithTimeout (context .Background (), testutil .WaitMedium )
144
145
defer cancel ()
145
146
147
+ return p .ExpectMatchContext (timeout , str )
148
+ }
149
+
150
+ // TODO(mafredri): Rename this to ExpectMatch when refactoring.
151
+ func (p * PTY ) ExpectMatchContext (ctx context.Context , str string ) string {
152
+ p .t .Helper ()
153
+
146
154
var buffer bytes.Buffer
147
- match := make (chan error , 1 )
148
- go func () {
149
- defer close (match )
150
- match <- func () error {
151
- for {
152
- r , _ , err := p .runeReader .ReadRune ()
153
- if err != nil {
154
- return err
155
- }
156
- _ , err = buffer .WriteRune (r )
157
- if err != nil {
158
- return err
159
- }
160
- if strings .Contains (buffer .String (), str ) {
161
- return nil
162
- }
155
+ err := p .doMatchWithDeadline (ctx , "ExpectMatchContext" , func () error {
156
+ for {
157
+ r , _ , err := p .runeReader .ReadRune ()
158
+ if err != nil {
159
+ return err
160
+ }
161
+ _ , err = buffer .WriteRune (r )
162
+ if err != nil {
163
+ return err
164
+ }
165
+ if strings .Contains (buffer .String (), str ) {
166
+ return nil
163
167
}
164
- }()
165
- }()
166
-
167
- select {
168
- case err := <- match :
169
- if err != nil {
170
- p .fatalf ("read error" , "%v (wanted %q; got %q)" , err , str , buffer .String ())
171
- return ""
172
168
}
173
- p .logf ("matched %q = %q" , str , buffer .String ())
174
- return buffer .String ()
175
- case <- timeout .Done ():
176
- // Ensure goroutine is cleaned up before test exit.
177
- _ = p .close ("expect match timeout" )
178
- <- match
179
-
180
- p .fatalf ("match exceeded deadline" , "wanted %q; got %q" , str , buffer .String ())
169
+ })
170
+ if err != nil {
171
+ p .fatalf ("read error" , "%v (wanted %q; got %q)" , err , str , buffer .String ())
181
172
return ""
182
173
}
174
+ p .logf ("matched %q = %q" , str , buffer .String ())
175
+ return buffer .String ()
183
176
}
184
177
185
- func (p * PTY ) ReadRune (ctx context.Context ) rune {
178
+ func (p * PTY ) Peek (ctx context.Context , n int ) [] byte {
186
179
p .t .Helper ()
187
180
188
- // A timeout is mandatory, caller can decide by passing a context
189
- // that times out.
190
- if _ , ok := ctx . Deadline (); ! ok {
191
- timeout := testutil . WaitMedium
192
- p . logf ( "ReadRune ctx has no deadline, using %s" , timeout )
193
- var cancel context. CancelFunc
194
- //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
195
- ctx , cancel = context . WithTimeout ( ctx , timeout )
196
- defer cancel ()
181
+ var out [] byte
182
+ err := p . doMatchWithDeadline ( ctx , "Peek" , func () error {
183
+ var err error
184
+ out , err = p . runeReader . Peek ( n )
185
+ return err
186
+ })
187
+ if err != nil {
188
+ p . fatalf ( "read error" , "%v (wanted %d bytes; got %d: %q)" , err , n , len ( out ), out )
189
+ return nil
197
190
}
191
+ p .logf ("peeked %d/%d bytes = %q" , len (out ), n , out )
192
+ return slices .Clone (out )
193
+ }
194
+
195
+ func (p * PTY ) ReadRune (ctx context.Context ) rune {
196
+ p .t .Helper ()
198
197
199
198
var r rune
200
- match := make (chan error , 1 )
201
- go func () {
202
- defer close (match )
199
+ err := p .doMatchWithDeadline (ctx , "ReadRune" , func () error {
203
200
var err error
204
201
r , _ , err = p .runeReader .ReadRune ()
205
- match <- err
206
- }()
207
-
208
- select {
209
- case err := <- match :
210
- if err != nil {
211
- p .fatalf ("read error" , "%v (wanted newline; got %q)" , err , r )
212
- return 0
213
- }
214
- p .logf ("matched rune = %q" , r )
215
- return r
216
- case <- ctx .Done ():
217
- // Ensure goroutine is cleaned up before test exit.
218
- _ = p .close ("read rune context done: " + ctx .Err ().Error ())
219
- <- match
220
-
221
- p .fatalf ("read rune context done" , "wanted rune; got nothing" )
202
+ return err
203
+ })
204
+ if err != nil {
205
+ p .fatalf ("read error" , "%v (wanted rune; got %q)" , err , r )
222
206
return 0
223
207
}
208
+ p .logf ("matched rune = %q" , r )
209
+ return r
224
210
}
225
211
226
- func (p * PTY ) ReadLine () string {
212
+ func (p * PTY ) ReadLine (ctx context. Context ) string {
227
213
p .t .Helper ()
228
214
229
- // timeout, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
230
- timeout , cancel := context .WithCancel (context .Background ())
231
- defer cancel ()
232
-
233
215
var buffer bytes.Buffer
234
- match := make (chan error , 1 )
235
- go func () {
236
- defer close (match )
237
- match <- func () error {
238
- for {
239
- r , _ , err := p .runeReader .ReadRune ()
240
- if err != nil {
241
- return err
242
- }
243
- if r == '\n' {
216
+ err := p .doMatchWithDeadline (ctx , "ReadLine" , func () error {
217
+ for {
218
+ r , _ , err := p .runeReader .ReadRune ()
219
+ if err != nil {
220
+ return err
221
+ }
222
+ if r == '\n' {
223
+ return nil
224
+ }
225
+ if r == '\r' {
226
+ // Peek the next rune to see if it's an LF and then consume
227
+ // it.
228
+
229
+ // Unicode code points can be up to 4 bytes, but the
230
+ // ones we're looking for are only 1 byte.
231
+ b , _ := p .runeReader .Peek (1 )
232
+ if len (b ) == 0 {
244
233
return nil
245
234
}
246
- if r == '\r' {
247
- // Peek the next rune to see if it's an LF and then consume
248
- // it.
249
-
250
- // Unicode code points can be up to 4 bytes, but the
251
- // ones we're looking for are only 1 byte.
252
- b , _ := p .runeReader .Peek (1 )
253
- if len (b ) == 0 {
254
- return nil
255
- }
256
235
257
- r , _ = utf8 .DecodeRune (b )
258
- if r == '\n' {
259
- _ , _ , err = p .runeReader .ReadRune ()
260
- if err != nil {
261
- return err
262
- }
236
+ r , _ = utf8 .DecodeRune (b )
237
+ if r == '\n' {
238
+ _ , _ , err = p .runeReader .ReadRune ()
239
+ if err != nil {
240
+ return err
263
241
}
264
-
265
- return nil
266
242
}
267
243
268
- _ , err = buffer .WriteRune (r )
269
- if err != nil {
270
- return err
271
- }
244
+ return nil
272
245
}
273
- }()
274
- }()
275
246
247
+ _ , err = buffer .WriteRune (r )
248
+ if err != nil {
249
+ return err
250
+ }
251
+ }
252
+ })
253
+ if err != nil {
254
+ p .fatalf ("read error" , "%v (wanted newline; got %q)" , err , buffer .String ())
255
+ return ""
256
+ }
257
+ p .logf ("matched newline = %q" , buffer .String ())
258
+ return buffer .String ()
259
+ }
260
+
261
+ func (p * PTY ) doMatchWithDeadline (ctx context.Context , name string , fn func () error ) error {
262
+ p .t .Helper ()
263
+
264
+ // A timeout is mandatory, caller can decide by passing a context
265
+ // that times out.
266
+ if _ , ok := ctx .Deadline (); ! ok {
267
+ timeout := testutil .WaitMedium
268
+ p .logf ("%s ctx has no deadline, using %s" , name , timeout )
269
+ var cancel context.CancelFunc
270
+ //nolint:gocritic // Rule guard doesn't detect that we're using testutil.Wait*.
271
+ ctx , cancel = context .WithTimeout (ctx , timeout )
272
+ defer cancel ()
273
+ }
274
+
275
+ match := make (chan error , 1 )
276
+ go func () {
277
+ defer close (match )
278
+ match <- fn ()
279
+ }()
276
280
select {
277
281
case err := <- match :
278
- if err != nil {
279
- p .fatalf ("read error" , "%v (wanted newline; got %q)" , err , buffer .String ())
280
- return ""
281
- }
282
- p .logf ("matched newline = %q" , buffer .String ())
283
- return buffer .String ()
284
- case <- timeout .Done ():
282
+ return err
283
+ case <- ctx .Done ():
285
284
// Ensure goroutine is cleaned up before test exit.
286
- _ = p .close ("expect match timeout " )
285
+ _ = p .close ("match deadline exceeded " )
287
286
<- match
288
287
289
- p .fatalf ("match exceeded deadline" , "wanted newline; got %q" , buffer .String ())
290
- return ""
288
+ return xerrors .Errorf ("match deadline exceeded: %w" , ctx .Err ())
291
289
}
292
290
}
293
291
0 commit comments