1
2
3
4
5 package tls
6
7 import (
8 "context"
9 "errors"
10 "fmt"
11 )
12
13
14
// QUICEncryptionLevel identifies one of the encryption levels at which
// handshake data and secrets are exchanged with a QUIC implementation.
type QUICEncryptionLevel int

// The encryption levels, numbered from zero in declaration order.
const (
	QUICEncryptionLevelInitial QUICEncryptionLevel = iota
	QUICEncryptionLevelEarly
	QUICEncryptionLevelHandshake
	QUICEncryptionLevelApplication
)

// String returns the short name of the level: "Initial", "Early",
// "Handshake" or "Application". Any other value is rendered as
// "QUICEncryptionLevel(N)", where N is its integer value.
func (l QUICEncryptionLevel) String() string {
	names := [...]string{
		QUICEncryptionLevelInitial:     "Initial",
		QUICEncryptionLevelEarly:       "Early",
		QUICEncryptionLevelHandshake:   "Handshake",
		QUICEncryptionLevelApplication: "Application",
	}
	// Anything outside the table is not a declared level.
	if l < 0 || int(l) >= len(names) {
		return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
	}
	return names[l]
}
38
39
40
41
42
// QUICConn is the handle through which a QUIC implementation drives a TLS
// handshake. It wraps the *Conn created by QUICClient or QUICServer.
//
// NOTE(review): the methods read and write the shared quicState without any
// locking of their own, so they are presumably not safe for concurrent
// use — confirm.
type QUICConn struct {
	conn *Conn

	// sessionTicketSent records that SendSessionTicket has already issued a
	// ticket, so that a second call is rejected.
	sessionTicketSent bool
}
48
49
// QUICConfig holds the settings used by QUICClient and QUICServer to create a
// QUICConn.
type QUICConfig struct {
	// TLSConfig is the TLS configuration handed to Client or Server when the
	// underlying connection is created.
	TLSConfig *Config

	// EnableSessionEvents is copied into the connection's QUIC state when the
	// QUICConn is created.
	// NOTE(review): the code that consumes the flag is outside this view;
	// presumably it switches on the QUICResumeSession and QUICStoreSession
	// events — confirm.
	EnableSessionEvents bool
}
60
61
// QUICEventKind discriminates the events returned by QUICConn.NextEvent.
type QUICEventKind int

const (
	// QUICNoEvent is the kind of the event NextEvent returns when the event
	// queue is empty.
	QUICNoEvent = QUICEventKind(iota)

	// QUICSetReadSecret and QUICSetWriteSecret deliver the read and write
	// secret for an encryption level.
	// QUICEvent.Level, QUICEvent.Suite and QUICEvent.Data (the secret) are set.
	QUICSetReadSecret
	QUICSetWriteSecret

	// QUICWriteData delivers handshake bytes produced at an encryption level.
	// QUICEvent.Level and QUICEvent.Data are set. Consecutive writes at the
	// same level are merged into a single event.
	QUICWriteData

	// QUICTransportParameters delivers transport parameters in
	// QUICEvent.Data.
	// NOTE(review): presumably these are the peer's parameters — confirm
	// against the caller of quicSetTransportParameters.
	QUICTransportParameters

	// QUICTransportParametersRequired is queued when the handshake needs the
	// local transport parameters and none have been set yet. The handshake
	// stays blocked until QUICConn.SetTransportParameters is called.
	//
	// The event is never generated if the parameters were already set when
	// the handshake asked for them.
	QUICTransportParametersRequired

	// QUICRejectedEarlyData carries no additional fields.
	// NOTE(review): judging by the name it reports that early data was
	// rejected; the code that raises it is outside this view — confirm.
	QUICRejectedEarlyData

	// QUICHandshakeDone carries no additional fields. It is queued by
	// quicHandshakeComplete.
	QUICHandshakeDone

	// QUICResumeSession delivers a session in QUICEvent.SessionState.
	// After queuing it the handshake waits until every queued event has been
	// consumed through QUICConn.NextEvent and NextEvent is called once more.
	QUICResumeSession

	// QUICStoreSession delivers a session in QUICEvent.SessionState.
	// NOTE(review): presumably the application is expected to pass it to
	// QUICConn.StoreSession — confirm.
	QUICStoreSession
)
121
122
123
124
125
// QUICEvent is one event returned by QUICConn.NextEvent.
//
// Kind says what happened; which of the remaining fields carry a value
// depends on Kind.
type QUICEvent struct {
	// Kind selects which of the fields below are meaningful.
	Kind QUICEventKind

	// Level is set for QUICSetReadSecret, QUICSetWriteSecret and
	// QUICWriteData.
	Level QUICEncryptionLevel

	// Data is set for QUICTransportParameters, QUICSetReadSecret,
	// QUICSetWriteSecret and QUICWriteData.
	// NOTE(review): the slice is not copied when the event is queued or
	// returned, so it presumably remains owned by the connection — confirm
	// how long callers may hold on to it.
	Data []byte

	// Suite is set for QUICSetReadSecret and QUICSetWriteSecret.
	Suite uint16

	// SessionState is set for QUICResumeSession and QUICStoreSession.
	SessionState *SessionState
}
142
// quicState holds the QUIC-specific state attached to a Conn by newQUICConn.
// It is shared between the QUICConn methods and the handshake goroutine,
// which hand control back and forth over signalc and blockedc.
type quicState struct {
	// events is the queue of pending events and nextEvent the index of the
	// next one NextEvent will return. Both are reset once the queue has been
	// fully consumed.
	events    []QUICEvent
	nextEvent int

	// eventArr is the initial backing array of events (see newQUICConn), so
	// a queue of up to eight events needs no separate allocation.
	eventArr [8]QUICEvent

	// Handshake control. cancel is invoked by Close and is nil until it is
	// assigned; it and cancelc are assigned outside this view.
	started  bool            // set by Start, which may only be called once
	signalc  chan struct{}   // handshake goroutine sends, QUICConn methods receive, to resume it
	blockedc chan struct{}   // handshake goroutine sends to report that it is blocked
	cancelc  <-chan struct{} // aborts a handshake blocked in quicWaitForSignal
	cancel   context.CancelFunc

	// waitingForDrain is set by quicResumeSession and cleared by NextEvent
	// once the event queue has been fully consumed.
	waitingForDrain bool

	// readbuf is how HandleData passes received bytes to the handshake
	// goroutine: HandleData stores the slice and receives from signalc, and
	// quicWaitForSignal, after its send on signalc completes, appends the
	// bytes to the connection's handshake buffer and clears the field.
	readbuf []byte

	// transportParams is set by SetTransportParameters and read by
	// quicGetTransportParameters; nil means "not provided yet".
	transportParams []byte

	// enableSessionEvents is copied from QUICConfig.EnableSessionEvents.
	enableSessionEvents bool
}
170
171
172
173
174
175 func QUICClient(config *QUICConfig) *QUICConn {
176 return newQUICConn(Client(nil, config.TLSConfig), config)
177 }
178
179
180
181
182
183 func QUICServer(config *QUICConfig) *QUICConn {
184 return newQUICConn(Server(nil, config.TLSConfig), config)
185 }
186
187 func newQUICConn(conn *Conn, config *QUICConfig) *QUICConn {
188 conn.quic = &quicState{
189 signalc: make(chan struct{}),
190 blockedc: make(chan struct{}),
191 enableSessionEvents: config.EnableSessionEvents,
192 }
193 conn.quic.events = conn.quic.eventArr[:0]
194 return &QUICConn{
195 conn: conn,
196 }
197 }
198
199
200
201
202
203 func (q *QUICConn) Start(ctx context.Context) error {
204 if q.conn.quic.started {
205 return quicError(errors.New("tls: Start called more than once"))
206 }
207 q.conn.quic.started = true
208 if q.conn.config.MinVersion < VersionTLS13 {
209 return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.3"))
210 }
211 go q.conn.HandshakeContext(ctx)
212 if _, ok := <-q.conn.quic.blockedc; !ok {
213 return q.conn.handshakeErr
214 }
215 return nil
216 }
217
218
219
220 func (q *QUICConn) NextEvent() QUICEvent {
221 qs := q.conn.quic
222 if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
223
224
225 qs.events[last].Data[0] = 0
226 }
227 if qs.nextEvent >= len(qs.events) && qs.waitingForDrain {
228 qs.waitingForDrain = false
229 <-qs.signalc
230 <-qs.blockedc
231 }
232 if qs.nextEvent >= len(qs.events) {
233 qs.events = qs.events[:0]
234 qs.nextEvent = 0
235 return QUICEvent{Kind: QUICNoEvent}
236 }
237 e := qs.events[qs.nextEvent]
238 qs.events[qs.nextEvent] = QUICEvent{}
239 qs.nextEvent++
240 return e
241 }
242
243
244 func (q *QUICConn) Close() error {
245 if q.conn.quic.cancel == nil {
246 return nil
247 }
248 q.conn.quic.cancel()
249 for range q.conn.quic.blockedc {
250
251 }
252 return q.conn.handshakeErr
253 }
254
255
256
257 func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
258 c := q.conn
259 if c.in.level != level {
260 return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
261 }
262 c.quic.readbuf = data
263 <-c.quic.signalc
264 _, ok := <-c.quic.blockedc
265 if ok {
266
267 return nil
268 }
269
270 c.handshakeMutex.Lock()
271 defer c.handshakeMutex.Unlock()
272 c.hand.Write(c.quic.readbuf)
273 c.quic.readbuf = nil
274 for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
275 b := q.conn.hand.Bytes()
276 n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
277 if n > maxHandshake {
278 q.conn.handshakeErr = fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)
279 break
280 }
281 if len(b) < 4+n {
282 return nil
283 }
284 if err := q.conn.handlePostHandshakeMessage(); err != nil {
285 q.conn.handshakeErr = err
286 }
287 }
288 if q.conn.handshakeErr != nil {
289 return quicError(q.conn.handshakeErr)
290 }
291 return nil
292 }
293
// QUICSessionTicketOptions carries the arguments of
// QUICConn.SendSessionTicket.
type QUICSessionTicketOptions struct {
	// EarlyData is forwarded as the first argument of the connection's
	// sendSessionTicket.
	// NOTE(review): presumably it marks the ticket as usable for early
	// (0-RTT) data — confirm against sendSessionTicket.
	EarlyData bool

	// Extra is forwarded unchanged as the second argument of the
	// connection's sendSessionTicket.
	// NOTE(review): its meaning is defined there, outside this view.
	Extra [][]byte
}
299
300
301
302
303 func (q *QUICConn) SendSessionTicket(opts QUICSessionTicketOptions) error {
304 c := q.conn
305 if c.config.SessionTicketsDisabled {
306 return nil
307 }
308 if !c.isHandshakeComplete.Load() {
309 return quicError(errors.New("tls: SendSessionTicket called before handshake completed"))
310 }
311 if c.isClient {
312 return quicError(errors.New("tls: SendSessionTicket called on the client"))
313 }
314 if q.sessionTicketSent {
315 return quicError(errors.New("tls: SendSessionTicket called multiple times"))
316 }
317 q.sessionTicketSent = true
318 return quicError(c.sendSessionTicket(opts.EarlyData, opts.Extra))
319 }
320
321
322
323
324
325 func (q *QUICConn) StoreSession(session *SessionState) error {
326 c := q.conn
327 if !c.isClient {
328 return quicError(errors.New("tls: StoreSessionTicket called on the server"))
329 }
330 cacheKey := c.clientSessionCacheKey()
331 if cacheKey == "" {
332 return nil
333 }
334 cs := &ClientSessionState{session: session}
335 c.config.ClientSessionCache.Put(cacheKey, cs)
336 return nil
337 }
338
339
340 func (q *QUICConn) ConnectionState() ConnectionState {
341 return q.conn.ConnectionState()
342 }
343
344
345
346
347
348 func (q *QUICConn) SetTransportParameters(params []byte) {
349 if params == nil {
350 params = []byte{}
351 }
352 q.conn.quic.transportParams = params
353 if q.conn.quic.started {
354 <-q.conn.quic.signalc
355 <-q.conn.quic.blockedc
356 }
357 }
358
359
360
361 func quicError(err error) error {
362 if err == nil {
363 return nil
364 }
365 var ae AlertError
366 if errors.As(err, &ae) {
367 return err
368 }
369 var a alert
370 if !errors.As(err, &a) {
371 a = alertInternalError
372 }
373
374
375 return fmt.Errorf("%w%.0w", err, AlertError(a))
376 }
377
378 func (c *Conn) quicReadHandshakeBytes(n int) error {
379 for c.hand.Len() < n {
380 if err := c.quicWaitForSignal(); err != nil {
381 return err
382 }
383 }
384 return nil
385 }
386
387 func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
388 c.quic.events = append(c.quic.events, QUICEvent{
389 Kind: QUICSetReadSecret,
390 Level: level,
391 Suite: suite,
392 Data: secret,
393 })
394 }
395
396 func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
397 c.quic.events = append(c.quic.events, QUICEvent{
398 Kind: QUICSetWriteSecret,
399 Level: level,
400 Suite: suite,
401 Data: secret,
402 })
403 }
404
405 func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
406 var last *QUICEvent
407 if len(c.quic.events) > 0 {
408 last = &c.quic.events[len(c.quic.events)-1]
409 }
410 if last == nil || last.Kind != QUICWriteData || last.Level != level {
411 c.quic.events = append(c.quic.events, QUICEvent{
412 Kind: QUICWriteData,
413 Level: level,
414 })
415 last = &c.quic.events[len(c.quic.events)-1]
416 }
417 last.Data = append(last.Data, data...)
418 }
419
420 func (c *Conn) quicResumeSession(session *SessionState) error {
421 c.quic.events = append(c.quic.events, QUICEvent{
422 Kind: QUICResumeSession,
423 SessionState: session,
424 })
425 c.quic.waitingForDrain = true
426 for c.quic.waitingForDrain {
427 if err := c.quicWaitForSignal(); err != nil {
428 return err
429 }
430 }
431 return nil
432 }
433
434 func (c *Conn) quicStoreSession(session *SessionState) {
435 c.quic.events = append(c.quic.events, QUICEvent{
436 Kind: QUICStoreSession,
437 SessionState: session,
438 })
439 }
440
441 func (c *Conn) quicSetTransportParameters(params []byte) {
442 c.quic.events = append(c.quic.events, QUICEvent{
443 Kind: QUICTransportParameters,
444 Data: params,
445 })
446 }
447
448 func (c *Conn) quicGetTransportParameters() ([]byte, error) {
449 if c.quic.transportParams == nil {
450 c.quic.events = append(c.quic.events, QUICEvent{
451 Kind: QUICTransportParametersRequired,
452 })
453 }
454 for c.quic.transportParams == nil {
455 if err := c.quicWaitForSignal(); err != nil {
456 return nil, err
457 }
458 }
459 return c.quic.transportParams, nil
460 }
461
462 func (c *Conn) quicHandshakeComplete() {
463 c.quic.events = append(c.quic.events, QUICEvent{
464 Kind: QUICHandshakeDone,
465 })
466 }
467
468 func (c *Conn) quicRejectedEarlyData() {
469 c.quic.events = append(c.quic.events, QUICEvent{
470 Kind: QUICRejectedEarlyData,
471 })
472 }
473
474
475
476
477
478
// quicWaitForSignal is called on the handshake goroutine when the handshake
// cannot make progress. It reports the blocked state on blockedc and then
// waits on signalc until a QUICConn method lets it continue.
//
// The caller must hold c.handshakeMutex: it is released while waiting and
// re-acquired before returning. If cancelc fires at either step, the result
// of sendAlertLocked(alertCloseNotify) is returned instead of nil.
func (c *Conn) quicWaitForSignal() error {
	// Release the handshake mutex for the duration of the wait, so that
	// QUICConn methods that take it are not locked out. The deferred Lock
	// runs after any return expression below has been evaluated.
	c.handshakeMutex.Unlock()
	defer c.handshakeMutex.Lock()

	// Step 1: tell the QUICConn that the handshake is blocked. The QUICConn
	// methods receive from blockedc before returning to their caller.
	select {
	case c.quic.blockedc <- struct{}{}:
	case <-c.quic.cancelc:
		return c.sendAlertLocked(alertCloseNotify)
	}

	// Step 2: wait until a QUICConn method receives from signalc. Once the
	// send completes, move any bytes left in readbuf by HandleData into the
	// handshake buffer.
	select {
	case c.quic.signalc <- struct{}{}:
		c.hand.Write(c.quic.readbuf)
		c.quic.readbuf = nil
	case <-c.quic.cancelc:
		return c.sendAlertLocked(alertCloseNotify)
	}
	return nil
}
504
View as plain text