@@ -9,8 +9,10 @@ import (
9
9
"testing"
10
10
"time"
11
11
12
+ "github.com/google/uuid"
12
13
"github.com/stretchr/testify/assert"
13
14
"github.com/stretchr/testify/require"
15
+ "go.uber.org/mock/gomock"
14
16
"nhooyr.io/websocket"
15
17
"tailscale.com/tailcfg"
16
18
@@ -21,7 +23,7 @@ import (
21
23
"github.com/coder/coder/v2/codersdk"
22
24
"github.com/coder/coder/v2/codersdk/workspacesdk"
23
25
"github.com/coder/coder/v2/tailnet"
24
- "github.com/coder/coder/v2/tailnet/proto"
26
+ tailnetproto "github.com/coder/coder/v2/tailnet/proto"
25
27
"github.com/coder/coder/v2/tailnet/tailnettest"
26
28
"github.com/coder/coder/v2/testutil"
27
29
)
@@ -102,6 +104,7 @@ func TestWebsocketDialer_TokenController(t *testing.T) {
102
104
require .Equal (t , "" , gotToken )
103
105
104
106
clients = testutil .RequireRecvCtx (ctx , t , clientCh )
107
+ require .Nil (t , clients .WorkspaceUpdates )
105
108
clients .Closer .Close ()
106
109
107
110
err = testutil .RequireRecvCtx (ctx , t , wsErr )
@@ -273,7 +276,7 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
273
276
logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : true }).Leveled (slog .LevelDebug )
274
277
275
278
svr := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
276
- sVer := apiversion .New (proto . CurrentMajor , proto . CurrentMinor - 1 )
279
+ sVer := apiversion .New (2 , 2 )
277
280
278
281
// the following matches what Coderd does;
279
282
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
@@ -291,7 +294,10 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
291
294
svrURL , err := url .Parse (svr .URL )
292
295
require .NoError (t , err )
293
296
294
- uut := workspacesdk .NewWebsocketDialer (logger , svrURL , & websocket.DialOptions {})
297
+ uut := workspacesdk .NewWebsocketDialer (
298
+ logger , svrURL , & websocket.DialOptions {},
299
+ workspacesdk .WithWorkspaceUpdates (& tailnetproto.WorkspaceUpdatesRequest {}),
300
+ )
295
301
296
302
errCh := make (chan error , 1 )
297
303
go func () {
@@ -307,6 +313,84 @@ func TestWebsocketDialer_UplevelVersion(t *testing.T) {
307
313
require .NotEmpty (t , sdkErr .Helper )
308
314
}
309
315
316
+ func TestWebsocketDialer_WorkspaceUpdates (t * testing.T ) {
317
+ t .Parallel ()
318
+ ctx := testutil .Context (t , testutil .WaitShort )
319
+ logger := slogtest .Make (t , & slogtest.Options {
320
+ IgnoreErrors : true ,
321
+ }).Leveled (slog .LevelDebug )
322
+
323
+ fCoord := tailnettest .NewFakeCoordinator ()
324
+ var coord tailnet.Coordinator = fCoord
325
+ coordPtr := atomic.Pointer [tailnet.Coordinator ]{}
326
+ coordPtr .Store (& coord )
327
+ ctrl := gomock .NewController (t )
328
+ mProvider := tailnettest .NewMockWorkspaceUpdatesProvider (ctrl )
329
+
330
+ svc , err := tailnet .NewClientService (tailnet.ClientServiceOptions {
331
+ Logger : logger ,
332
+ CoordPtr : & coordPtr ,
333
+ DERPMapUpdateFrequency : time .Hour ,
334
+ DERPMapFn : func () * tailcfg.DERPMap { return & tailcfg.DERPMap {} },
335
+ WorkspaceUpdatesProvider : mProvider ,
336
+ })
337
+ require .NoError (t , err )
338
+
339
+ wsErr := make (chan error , 1 )
340
+ svr := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
341
+ // need 2.3 for WorkspaceUpdates RPC
342
+ cVer := r .URL .Query ().Get ("version" )
343
+ assert .Equal (t , "2.3" , cVer )
344
+
345
+ sws , err := websocket .Accept (w , r , nil )
346
+ if ! assert .NoError (t , err ) {
347
+ return
348
+ }
349
+ wsCtx , nc := codersdk .WebsocketNetConn (ctx , sws , websocket .MessageBinary )
350
+ // streamID can be empty because we don't call RPCs in this test.
351
+ wsErr <- svc .ServeConnV2 (wsCtx , nc , tailnet.StreamID {})
352
+ }))
353
+ defer svr .Close ()
354
+ svrURL , err := url .Parse (svr .URL )
355
+ require .NoError (t , err )
356
+
357
+ userID := uuid.UUID {88 }
358
+
359
+ mSub := tailnettest .NewMockSubscription (ctrl )
360
+ updateCh := make (chan * tailnetproto.WorkspaceUpdate , 1 )
361
+ mProvider .EXPECT ().Subscribe (gomock .Any (), userID ).Times (1 ).Return (mSub , nil )
362
+ mSub .EXPECT ().Updates ().MinTimes (1 ).Return (updateCh )
363
+ mSub .EXPECT ().Close ().Times (1 ).Return (nil )
364
+
365
+ uut := workspacesdk .NewWebsocketDialer (
366
+ logger , svrURL , & websocket.DialOptions {},
367
+ workspacesdk .WithWorkspaceUpdates (& tailnetproto.WorkspaceUpdatesRequest {
368
+ WorkspaceOwnerId : userID [:],
369
+ }),
370
+ )
371
+
372
+ clients , err := uut .Dial (ctx , nil )
373
+ require .NoError (t , err )
374
+ require .NotNil (t , clients .WorkspaceUpdates )
375
+
376
+ wsID := uuid.UUID {99 }
377
+ expectedUpdate := & tailnetproto.WorkspaceUpdate {
378
+ UpsertedWorkspaces : []* tailnetproto.Workspace {
379
+ {Id : wsID [:]},
380
+ },
381
+ }
382
+ updateCh <- expectedUpdate
383
+
384
+ gotUpdate , err := clients .WorkspaceUpdates .Recv ()
385
+ require .NoError (t , err )
386
+ require .Equal (t , wsID [:], gotUpdate .GetUpsertedWorkspaces ()[0 ].GetId ())
387
+
388
+ clients .Closer .Close ()
389
+
390
+ err = testutil .RequireRecvCtx (ctx , t , wsErr )
391
+ require .NoError (t , err )
392
+ }
393
+
310
394
type fakeResumeTokenController struct {
311
395
ctx context.Context
312
396
t testing.TB
0 commit comments