Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a04548d

Browse files
refactor: convert workspacesdk.AgentConn to an interface
We convert `workspacesdk.AgentConn` to an interface and generate a mock for it. This allows writing `coderd` tests that rely on the agent's HTTP api to not have to set up an entire tailnet networking stack.
1 parent a8c89a1 commit a04548d

File tree

18 files changed

+647
-143
lines changed

18 files changed

+647
-143
lines changed

Makefile

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,8 @@ GEN_FILES := \
636636
coderd/database/pubsub/psmock/psmock.go \
637637
agent/agentcontainers/acmock/acmock.go \
638638
agent/agentcontainers/dcspec/dcspec_gen.go \
639-
coderd/httpmw/loggermw/loggermock/loggermock.go
639+
coderd/httpmw/loggermw/loggermock/loggermock.go \
640+
codersdk/workspacesdk/agentconnmock/agentconnmock.go
640641

641642
# all gen targets should be added here and to gen/mark-fresh
642643
gen: gen/db gen/golden-files $(GEN_FILES)
@@ -686,6 +687,7 @@ gen/mark-fresh:
686687
agent/agentcontainers/acmock/acmock.go \
687688
agent/agentcontainers/dcspec/dcspec_gen.go \
688689
coderd/httpmw/loggermw/loggermock/loggermock.go \
690+
codersdk/workspacesdk/agentconnmock/agentconnmock.go \
689691
"
690692

691693
for file in $$files; do
@@ -729,6 +731,10 @@ coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.g
729731
go generate ./coderd/httpmw/loggermw/loggermock/
730732
touch "$@"
731733

734+
codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agentconn.go
735+
go generate ./codersdk/workspacesdk/agentconnmock/
736+
touch "$@"
737+
732738
agent/agentcontainers/dcspec/dcspec_gen.go: \
733739
node_modules/.installed \
734740
agent/agentcontainers/dcspec/devContainer.base.schema.json \

agent/agent_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2739,9 +2739,9 @@ func TestAgent_Dial(t *testing.T) {
27392739

27402740
switch l.Addr().Network() {
27412741
case "tcp":
2742-
conn, err = agentConn.Conn.DialContextTCP(ctx, ipp)
2742+
conn, err = agentConn.TailnetConn().DialContextTCP(ctx, ipp)
27432743
case "udp":
2744-
conn, err = agentConn.Conn.DialContextUDP(ctx, ipp)
2744+
conn, err = agentConn.TailnetConn().DialContextUDP(ctx, ipp)
27452745
default:
27462746
t.Fatalf("unknown network: %s", l.Addr().Network())
27472747
}
@@ -2799,7 +2799,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
27992799
})
28002800

28012801
// Setup a client connection.
2802-
newClientConn := func(derpMap *tailcfg.DERPMap, name string) *workspacesdk.AgentConn {
2802+
newClientConn := func(derpMap *tailcfg.DERPMap, name string) workspacesdk.AgentConn {
28032803
conn, err := tailnet.NewConn(&tailnet.Options{
28042804
Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()},
28052805
DERPMap: derpMap,
@@ -2879,13 +2879,13 @@ func TestAgent_UpdatedDERP(t *testing.T) {
28792879

28802880
// Connect from a second client and make sure it uses the new DERP map.
28812881
conn2 := newClientConn(newDerpMap, "client2")
2882-
require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs())
2882+
require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs())
28832883
t.Log("conn2 got the new DERPMap")
28842884

28852885
// If the first client gets a DERP map update, it should be able to
28862886
// reconnect just fine.
2887-
conn1.SetDERPMap(newDerpMap)
2888-
require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs())
2887+
conn1.TailnetConn().SetDERPMap(newDerpMap)
2888+
require.Equal(t, []int{2}, conn1.TailnetConn().DERPMap().RegionIDs())
28892889
t.Log("set the new DERPMap on conn1")
28902890
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
28912891
defer cancel()
@@ -3252,7 +3252,7 @@ func setupSSHSessionOnPort(
32523252
}
32533253

32543254
func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
3255-
*workspacesdk.AgentConn,
3255+
workspacesdk.AgentConn,
32563256
*agenttest.Client,
32573257
<-chan *proto.Stats,
32583258
afero.Fs,

cli/ping.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func (r *RootCmd) ping() *serpent.Command {
147147
}
148148
defer conn.Close()
149149

150-
derpMap := conn.DERPMap()
150+
derpMap := conn.TailnetConn().DERPMap()
151151

152152
diagCtx, diagCancel := context.WithTimeout(inv.Context(), 30*time.Second)
153153
defer diagCancel()
@@ -156,7 +156,7 @@ func (r *RootCmd) ping() *serpent.Command {
156156
// Silent ping to determine whether we should show diags
157157
_, didP2p, _, _ := conn.Ping(ctx)
158158

159-
ni := conn.GetNetInfo()
159+
ni := conn.TailnetConn().GetNetInfo()
160160
connDiags := cliui.ConnDiags{
161161
DisableDirect: r.disableDirect,
162162
LocalNetInfo: ni,

cli/portforward.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ func (r *RootCmd) portForward() *serpent.Command {
221221
func listenAndPortForward(
222222
ctx context.Context,
223223
inv *serpent.Invocation,
224-
conn *workspacesdk.AgentConn,
224+
conn workspacesdk.AgentConn,
225225
wg *sync.WaitGroup,
226226
spec portForwardSpec,
227227
logger slog.Logger,

cli/speedtest.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ func (r *RootCmd) speedtest() *serpent.Command {
139139
if err != nil {
140140
continue
141141
}
142-
status := conn.Status()
142+
status := conn.TailnetConn().Status()
143143
if len(status.Peers()) != 1 {
144144
continue
145145
}
@@ -189,7 +189,7 @@ func (r *RootCmd) speedtest() *serpent.Command {
189189
outputResult.Intervals[i] = interval
190190
}
191191
}
192-
conn.Conn.SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits)
192+
conn.TailnetConn().SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits)
193193
out, err := formatter.Format(inv.Context(), outputResult)
194194
if err != nil {
195195
return err

cli/ssh.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command {
590590
}
591591

592592
err = sshSession.Wait()
593-
conn.SendDisconnectedTelemetry()
593+
conn.TailnetConn().SendDisconnectedTelemetry()
594594
if err != nil {
595595
if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) {
596596
// Clear the error since it's not useful beyond
@@ -1364,7 +1364,7 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName {
13641364

13651365
func setStatsCallback(
13661366
ctx context.Context,
1367-
agentConn *workspacesdk.AgentConn,
1367+
agentConn workspacesdk.AgentConn,
13681368
logger slog.Logger,
13691369
networkInfoDir string,
13701370
networkInfoInterval time.Duration,
@@ -1437,7 +1437,7 @@ func setStatsCallback(
14371437

14381438
now := time.Now()
14391439
cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{})
1440-
agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb)
1440+
agentConn.TailnetConn().SetConnStatsCallback(networkInfoInterval, 2048, cb)
14411441
return errCh, nil
14421442
}
14431443

@@ -1451,13 +1451,13 @@ type sshNetworkStats struct {
14511451
UsingCoderConnect bool `json:"using_coder_connect"`
14521452
}
14531453

1454-
func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
1454+
func collectNetworkStats(ctx context.Context, agentConn workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
14551455
latency, p2p, pingResult, err := agentConn.Ping(ctx)
14561456
if err != nil {
14571457
return nil, err
14581458
}
1459-
node := agentConn.Node()
1460-
derpMap := agentConn.DERPMap()
1459+
node := agentConn.TailnetConn().Node()
1460+
derpMap := agentConn.TailnetConn().DERPMap()
14611461

14621462
totalRx := uint64(0)
14631463
totalTx := uint64(0)

coderd/coderd.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,9 @@ func New(options *Options) *API {
325325
})
326326
}
327327

328+
if options.PrometheusRegistry == nil {
329+
options.PrometheusRegistry = prometheus.NewRegistry()
330+
}
328331
if options.Authorizer == nil {
329332
options.Authorizer = rbac.NewCachingAuthorizer(options.PrometheusRegistry)
330333
if buildinfo.IsDev() {
@@ -381,9 +384,6 @@ func New(options *Options) *API {
381384
if options.FilesRateLimit == 0 {
382385
options.FilesRateLimit = 12
383386
}
384-
if options.PrometheusRegistry == nil {
385-
options.PrometheusRegistry = prometheus.NewRegistry()
386-
}
387387
if options.Clock == nil {
388388
options.Clock = quartz.NewReal()
389389
}

coderd/tailnet.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,9 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (
277277
}, nil
278278
}
279279

280-
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) {
280+
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
281281
var (
282-
conn *workspacesdk.AgentConn
282+
conn workspacesdk.AgentConn
283283
ret func()
284284
)
285285

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package coderd
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"database/sql"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
11+
"net/http/httputil"
12+
"net/url"
13+
"strings"
14+
"testing"
15+
16+
"cdr.dev/slog"
17+
"cdr.dev/slog/sloggers/slogtest"
18+
"github.com/coder/coder/v2/coderd/database"
19+
"github.com/coder/coder/v2/coderd/database/dbmock"
20+
"github.com/coder/coder/v2/coderd/database/dbtime"
21+
"github.com/coder/coder/v2/coderd/httpmw"
22+
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
23+
"github.com/coder/coder/v2/codersdk"
24+
"github.com/coder/coder/v2/codersdk/workspacesdk"
25+
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
26+
"github.com/coder/coder/v2/codersdk/wsjson"
27+
"github.com/coder/coder/v2/tailnet"
28+
"github.com/coder/coder/v2/tailnet/tailnettest"
29+
"github.com/coder/coder/v2/testutil"
30+
"github.com/coder/websocket"
31+
"github.com/go-chi/chi/v5"
32+
"github.com/google/uuid"
33+
"github.com/stretchr/testify/require"
34+
"go.uber.org/mock/gomock"
35+
)
36+
37+
type fakeAgentProvider struct {
38+
agentConn func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error)
39+
}
40+
41+
func (fakeAgentProvider) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHost string) *httputil.ReverseProxy {
42+
panic("unimplemented")
43+
}
44+
45+
func (f fakeAgentProvider) AgentConn(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) {
46+
if f.agentConn != nil {
47+
return f.agentConn(ctx, agentID)
48+
}
49+
50+
panic("unimplemented")
51+
}
52+
53+
func (fakeAgentProvider) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
54+
panic("unimplemented")
55+
}
56+
57+
func (fakeAgentProvider) Close() error {
58+
return nil
59+
}
60+
61+
func TestWatchAgentContainers(t *testing.T) {
62+
t.Parallel()
63+
64+
t.Run("WebSocketClosesProperly", func(t *testing.T) {
65+
t.Parallel()
66+
67+
var (
68+
ctx = testutil.Context(t, testutil.WaitShort)
69+
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug).Named("coderd")
70+
71+
mCtrl = gomock.NewController(t)
72+
mDB = dbmock.NewMockStore(mCtrl)
73+
mCoordinator = tailnettest.NewMockCoordinator(mCtrl)
74+
mAgentConn = agentconnmock.NewMockAgentConn(mCtrl)
75+
76+
fAgentProvider = fakeAgentProvider{
77+
agentConn: func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) {
78+
return mAgentConn, func() {}, nil
79+
},
80+
}
81+
82+
workspaceID = uuid.New()
83+
agentID = uuid.New()
84+
resourceID = uuid.New()
85+
jobID = uuid.New()
86+
buildID = uuid.New()
87+
88+
containersCh = make(chan codersdk.WorkspaceAgentListContainersResponse)
89+
90+
r = chi.NewMux()
91+
92+
api = API{
93+
ctx: ctx,
94+
Options: &Options{
95+
AgentInactiveDisconnectTimeout: testutil.WaitShort,
96+
Database: mDB,
97+
Logger: logger,
98+
DeploymentValues: &codersdk.DeploymentValues{},
99+
TailnetCoordinator: tailnettest.NewFakeCoordinator(),
100+
},
101+
}
102+
)
103+
104+
var tailnetCoordinator tailnet.Coordinator = mCoordinator
105+
api.TailnetCoordinator.Store(&tailnetCoordinator)
106+
api.agentProvider = fAgentProvider
107+
108+
// Setup: Allow `ExtractWorkspaceAgentParams` to complete.
109+
mDB.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(database.WorkspaceAgent{
110+
ID: agentID,
111+
ResourceID: resourceID,
112+
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
113+
FirstConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
114+
LastConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
115+
}, nil)
116+
mDB.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).Return(database.WorkspaceResource{
117+
ID: resourceID,
118+
JobID: jobID,
119+
}, nil)
120+
mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), jobID).Return(database.ProvisionerJob{
121+
ID: jobID,
122+
Type: database.ProvisionerJobTypeWorkspaceBuild,
123+
}, nil)
124+
mDB.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), jobID).Return(database.WorkspaceBuild{
125+
WorkspaceID: workspaceID,
126+
ID: buildID,
127+
}, nil)
128+
129+
// And: Allow `db2dsk.WorkspaceAgent` to complete.
130+
mCoordinator.EXPECT().Node(gomock.Any()).Return(nil)
131+
132+
// And: Allow `WatchContainers` to be called.
133+
mAgentConn.EXPECT().WatchContainers(gomock.Any(), gomock.Any()).
134+
Return(containersCh, io.NopCloser(&bytes.Buffer{}), nil)
135+
136+
// And: We mount the HTTP Handler
137+
r.With(httpmw.ExtractWorkspaceAgentParam(mDB)).
138+
Get("/workspaceagents/{workspaceagent}/containers/watch", api.watchWorkspaceAgentContainers)
139+
140+
// Given: We create the HTTP server
141+
srv := httptest.NewServer(r)
142+
defer srv.Close()
143+
144+
// And: Dial the WebSocket
145+
wsURL := strings.Replace(srv.URL, "http://", "ws://", 1)
146+
conn, _, err := websocket.Dial(ctx, fmt.Sprintf("%s/workspaceagents/%s/containers/watch", wsURL, agentID), nil)
147+
require.NoError(t, err)
148+
149+
// And: Create a streaming decoder
150+
decoder := wsjson.NewDecoder[codersdk.WorkspaceAgentListContainersResponse](conn, websocket.MessageText, logger)
151+
defer decoder.Close()
152+
decodeCh := decoder.Chan()
153+
154+
// When: We close the `containersCh`
155+
close(containersCh)
156+
157+
// Then: We expect `decodeCh` to be closed.
158+
select {
159+
case <-ctx.Done():
160+
t.Fail()
161+
162+
case _, ok := <-decodeCh:
163+
require.False(t, ok, "channel is expected to be closed")
164+
}
165+
})
166+
}

0 commit comments

Comments
 (0)