@@ -2,12 +2,14 @@ package files_test
2
2
3
3
import (
4
4
"context"
5
+ "sync"
5
6
"sync/atomic"
6
7
"testing"
7
8
"time"
8
9
9
10
"github.com/google/uuid"
10
11
"github.com/prometheus/client_golang/prometheus"
12
+ "github.com/stretchr/testify/assert"
11
13
"github.com/stretchr/testify/require"
12
14
"go.uber.org/mock/gomock"
13
15
"golang.org/x/sync/errgroup"
@@ -26,6 +28,63 @@ import (
26
28
"github.com/coder/coder/v2/testutil"
27
29
)
28
30
31
+ // TestCancelledFetch runs 2 Acquire calls. The first fails with a ctx.Canceled
32
+ // error. The second call should ignore the first error and try to fetch the file
33
+ // again, which should succeed.
34
+ func TestCancelledFetch (t * testing.T ) {
35
+ t .Parallel ()
36
+
37
+ fileID := uuid .New ()
38
+ rdy := make (chan struct {})
39
+ dbM := dbmock .NewMockStore (gomock .NewController (t ))
40
+
41
+ // First call should fail
42
+ dbM .EXPECT ().GetFileByID (gomock .Any (), gomock .Any ()).DoAndReturn (func (mTx context.Context , fileID uuid.UUID ) (database.File , error ) {
43
+ // Wait long enough for the second call to be queued up.
44
+ return database.File {}, context .Canceled
45
+ })
46
+
47
+ // Second call should succeed
48
+ dbM .EXPECT ().GetFileByID (gomock .Any (), gomock .Any ()).DoAndReturn (func (mTx context.Context , fileID uuid.UUID ) (database.File , error ) {
49
+ return database.File {
50
+ ID : fileID ,
51
+ Data : make ([]byte , 100 ),
52
+ }, nil
53
+ })
54
+
55
+ //nolint:gocritic // Unit testing
56
+ ctx := dbauthz .AsFileReader (testutil .Context (t , testutil .WaitShort ))
57
+ cache := files .New (prometheus .NewRegistry (), & coderdtest.FakeAuthorizer {})
58
+
59
+ var wg sync.WaitGroup
60
+
61
+ // First call that will fail
62
+ wg .Add (1 )
63
+ go func () {
64
+ close (rdy )
65
+ _ , err := cache .Acquire (ctx , dbM , fileID )
66
+ assert .ErrorIs (t , err , context .Canceled )
67
+ wg .Done ()
68
+ }()
69
+
70
+ // Second call, that should succeed
71
+ wg .Add (1 )
72
+ go func () {
73
+ // Wait until the first goroutine has started
74
+ <- rdy
75
+ fs , err := cache .Acquire (ctx , dbM , fileID )
76
+ assert .NoError (t , err )
77
+ if fs != nil {
78
+ fs .Close ()
79
+ }
80
+ wg .Done ()
81
+ }()
82
+
83
+ // We need that second Acquire call to be queued up
84
+ time .Sleep (testutil .IntervalFast )
85
+ wg .Wait ()
86
+ }
87
+
29
88
// nolint:paralleltest,tparallel // Serially testing is easier
30
89
func TestCacheRBAC (t * testing.T ) {
31
90
t .Parallel ()
0 commit comments