8
8
9
9
"github.com/google/uuid"
10
10
"github.com/prometheus/client_golang/prometheus"
11
- "github.com/spf13/afero"
12
11
"github.com/stretchr/testify/require"
12
+ "go.uber.org/mock/gomock"
13
13
"golang.org/x/sync/errgroup"
14
14
15
15
"cdr.dev/slog/sloggers/slogtest"
@@ -18,6 +18,7 @@ import (
18
18
"github.com/coder/coder/v2/coderd/database"
19
19
"github.com/coder/coder/v2/coderd/database/dbauthz"
20
20
"github.com/coder/coder/v2/coderd/database/dbgen"
21
+ "github.com/coder/coder/v2/coderd/database/dbmock"
21
22
"github.com/coder/coder/v2/coderd/database/dbtestutil"
22
23
"github.com/coder/coder/v2/coderd/files"
23
24
"github.com/coder/coder/v2/coderd/rbac"
@@ -58,7 +59,7 @@ func TestCacheRBAC(t *testing.T) {
58
59
require .Equal (t , 0 , cache .Count ())
59
60
rec .Reset ()
60
61
61
- _ , err := cache .Acquire (nobody , file .ID )
62
+ _ , err := cache .Acquire (nobody , db , file .ID )
62
63
require .Error (t , err )
63
64
require .True (t , rbac .IsUnauthorizedError (err ))
64
65
@@ -75,18 +76,18 @@ func TestCacheRBAC(t *testing.T) {
75
76
require .Equal (t , 0 , cache .Count ())
76
77
77
78
// Read the file with a file reader to put it into the cache.
78
- a , err := cache .Acquire (cacheReader , file .ID )
79
+ a , err := cache .Acquire (cacheReader , db , file .ID )
79
80
require .NoError (t , err )
80
81
require .Equal (t , 1 , cache .Count ())
81
82
82
83
// "nobody" should not be able to read the file.
83
- _ , err = cache .Acquire (nobody , file .ID )
84
+ _ , err = cache .Acquire (nobody , db , file .ID )
84
85
require .Error (t , err )
85
86
require .True (t , rbac .IsUnauthorizedError (err ))
86
87
require .Equal (t , 1 , cache .Count ())
87
88
88
89
// UserReader can
89
- b , err := cache .Acquire (userReader , file .ID )
90
+ b , err := cache .Acquire (userReader , db , file .ID )
90
91
require .NoError (t , err )
91
92
require .Equal (t , 1 , cache .Count ())
92
93
@@ -110,16 +111,21 @@ func TestConcurrency(t *testing.T) {
110
111
ctx := dbauthz .AsFileReader (t .Context ())
111
112
112
113
const fileSize = 10
113
- emptyFS := afero .NewIOFS (afero .NewReadOnlyFs (afero .NewMemMapFs ()))
114
114
var fetches atomic.Int64
115
115
reg := prometheus .NewRegistry ()
116
- c := files .New (func (_ context.Context , _ uuid.UUID ) (files.CacheEntryValue , error ) {
116
+
117
+ dbM := dbmock .NewMockStore (gomock .NewController (t ))
118
+ dbM .EXPECT ().GetFileByID (gomock .Any (), gomock .Any ()).DoAndReturn (func (mTx context.Context , fileID uuid.UUID ) (database.File , error ) {
117
119
fetches .Add (1 )
118
- // Wait long enough before returning to make sure that all of the goroutines
120
+ // Wait long enough before returning to make sure that all the goroutines
119
121
// will be waiting in line, ensuring that no one duplicated a fetch.
120
122
time .Sleep (testutil .IntervalMedium )
121
- return files.CacheEntryValue {FS : emptyFS , Size : fileSize }, nil
122
- }, reg , & coderdtest.FakeAuthorizer {})
123
+ return database.File {
124
+ Data : make ([]byte , fileSize ),
125
+ }, nil
126
+ }).AnyTimes ()
127
+
128
+ c := files .New (reg , & coderdtest.FakeAuthorizer {})
123
129
124
130
batches := 1000
125
131
groups := make ([]* errgroup.Group , 0 , batches )
@@ -137,7 +143,7 @@ func TestConcurrency(t *testing.T) {
137
143
g .Go (func () error {
138
144
// We don't bother to Release these references because the Cache will be
139
145
// released at the end of the test anyway.
140
- _ , err := c .Acquire (ctx , id )
146
+ _ , err := c .Acquire (ctx , dbM , id )
141
147
return err
142
148
})
143
149
}
@@ -164,14 +170,15 @@ func TestRelease(t *testing.T) {
164
170
ctx := dbauthz .AsFileReader (t .Context ())
165
171
166
172
const fileSize = 10
167
- emptyFS := afero .NewIOFS (afero .NewReadOnlyFs (afero .NewMemMapFs ()))
168
173
reg := prometheus .NewRegistry ()
169
- c := files . New ( func ( _ context. Context , _ uuid. UUID ) (files. CacheEntryValue , error ) {
170
- return files. CacheEntryValue {
171
- FS : emptyFS ,
172
- Size : fileSize ,
174
+ dbM := dbmock . NewMockStore ( gomock . NewController ( t ))
175
+ dbM . EXPECT (). GetFileByID ( gomock . Any (), gomock . Any ()). DoAndReturn ( func ( mTx context. Context , fileID uuid. UUID ) (database. File , error ) {
176
+ return database. File {
177
+ Data : make ([] byte , fileSize ) ,
173
178
}, nil
174
- }, reg , & coderdtest.FakeAuthorizer {})
179
+ }).AnyTimes ()
180
+
181
+ c := files .New (reg , & coderdtest.FakeAuthorizer {})
175
182
176
183
batches := 100
177
184
ids := make ([]uuid.UUID , 0 , batches )
@@ -184,9 +191,8 @@ func TestRelease(t *testing.T) {
184
191
batchSize := 10
185
192
for openedIdx , id := range ids {
186
193
for batchIdx := range batchSize {
187
- it , err := c .Acquire (ctx , id )
194
+ it , err := c .Acquire (ctx , dbM , id )
188
195
require .NoError (t , err )
189
- require .Equal (t , emptyFS , it .FS )
190
196
releases [id ] = append (releases [id ], it .Close )
191
197
192
198
// Each time a new file is opened, the metrics should be updated as so:
@@ -257,7 +263,7 @@ func cacheAuthzSetup(t *testing.T) (database.Store, *files.Cache, *coderdtest.Re
257
263
258
264
// Dbauthz wrap the db
259
265
db = dbauthz .New (db , rec , logger , coderdtest .AccessControlStorePointer ())
260
- c := files .NewFromStore ( db , reg , rec )
266
+ c := files .New ( reg , rec )
261
267
return db , c , rec
262
268
}
263
269
0 commit comments