Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 4dc08bc

Browse files
committed
feat: use Agent v2 API for Service Banner
1 parent 6f243f6 commit 4dc08bc

File tree

5 files changed

+170
-111
lines changed

5 files changed

+170
-111
lines changed

agent/agent.go

+31-20
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import (
4141
"github.com/coder/coder/v2/agent/agentproc"
4242
"github.com/coder/coder/v2/agent/agentscripts"
4343
"github.com/coder/coder/v2/agent/agentssh"
44+
"github.com/coder/coder/v2/agent/proto"
4445
"github.com/coder/coder/v2/agent/reconnectingpty"
4546
"github.com/coder/coder/v2/buildinfo"
4647
"github.com/coder/coder/v2/cli/gitauth"
@@ -95,7 +96,6 @@ type Client interface {
9596
PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error
9697
PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error
9798
PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error
98-
GetServiceBanner(ctx context.Context) (codersdk.ServiceBannerConfig, error)
9999
}
100100

101101
type Agent interface {
@@ -269,7 +269,6 @@ func (a *agent) init(ctx context.Context) {
269269
func (a *agent) runLoop(ctx context.Context) {
270270
go a.reportLifecycleLoop(ctx)
271271
go a.reportMetadataLoop(ctx)
272-
go a.fetchServiceBannerLoop(ctx)
273272
go a.manageProcessPriorityLoop(ctx)
274273

275274
for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
@@ -662,22 +661,23 @@ func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentL
662661
// fetchServiceBannerLoop fetches the service banner on an interval. It will
663662
// not be fetched immediately; the expectation is that it is primed elsewhere
664663
// (and must be done before the session actually starts).
665-
func (a *agent) fetchServiceBannerLoop(ctx context.Context) {
664+
func (a *agent) fetchServiceBannerLoop(ctx context.Context, aAPI proto.DRPCAgentClient) error {
666665
ticker := time.NewTicker(a.serviceBannerRefreshInterval)
667666
defer ticker.Stop()
668667
for {
669668
select {
670669
case <-ctx.Done():
671-
return
670+
return ctx.Err()
672671
case <-ticker.C:
673-
serviceBanner, err := a.client.GetServiceBanner(ctx)
672+
sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{})
674673
if err != nil {
675674
if ctx.Err() != nil {
676-
return
675+
return ctx.Err()
677676
}
678677
a.logger.Error(ctx, "failed to update service banner", slog.Error(err))
679-
continue
678+
return err
680679
}
680+
serviceBanner := proto.SDKServiceBannerFromProto(sbp)
681681
a.serviceBanner.Store(&serviceBanner)
682682
}
683683
}
@@ -693,10 +693,24 @@ func (a *agent) run(ctx context.Context) error {
693693
}
694694
a.sessionToken.Store(&sessionToken)
695695

696-
serviceBanner, err := a.client.GetServiceBanner(ctx)
696+
// Listen returns the dRPC connection we use for the Agent v2+ API
697+
conn, err := a.client.Listen(ctx)
698+
if err != nil {
699+
return err
700+
}
701+
defer func() {
702+
cErr := conn.Close()
703+
if cErr != nil {
704+
a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err))
705+
}
706+
}()
707+
708+
aAPI := proto.NewDRPCAgentClient(conn)
709+
sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{})
697710
if err != nil {
698711
return xerrors.Errorf("fetch service banner: %w", err)
699712
}
713+
serviceBanner := proto.SDKServiceBannerFromProto(sbp)
700714
a.serviceBanner.Store(&serviceBanner)
701715

702716
manifest, err := a.client.Manifest(ctx)
@@ -821,18 +835,6 @@ func (a *agent) run(ctx context.Context) error {
821835
network.SetBlockEndpoints(manifest.DisableDirectConnections)
822836
}
823837

824-
// Listen returns the dRPC connection we use for both Coordinator and DERPMap updates
825-
conn, err := a.client.Listen(ctx)
826-
if err != nil {
827-
return err
828-
}
829-
defer func() {
830-
cErr := conn.Close()
831-
if cErr != nil {
832-
a.logger.Debug(ctx, "error closing drpc connection", slog.Error(err))
833-
}
834-
}()
835-
836838
eg, egCtx := errgroup.WithContext(ctx)
837839
eg.Go(func() error {
838840
a.logger.Debug(egCtx, "running tailnet connection coordinator")
@@ -852,6 +854,15 @@ func (a *agent) run(ctx context.Context) error {
852854
return nil
853855
})
854856

857+
eg.Go(func() error {
858+
a.logger.Debug(egCtx, "running fetch server banner loop")
859+
err := a.fetchServiceBannerLoop(egCtx, aAPI)
860+
if err != nil {
861+
return xerrors.Errorf("fetch server banner loop: %w", err)
862+
}
863+
return nil
864+
})
865+
855866
return eg.Wait()
856867
}
857868

agent/agenttest/client.go

+87-25
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"tailscale.com/tailcfg"
1919

2020
"cdr.dev/slog"
21+
agentproto "github.com/coder/coder/v2/agent/proto"
2122
"github.com/coder/coder/v2/codersdk"
2223
"github.com/coder/coder/v2/codersdk/agentsdk"
2324
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
@@ -48,6 +49,9 @@ func NewClient(t testing.TB,
4849
}
4950
err := proto.DRPCRegisterTailnet(mux, drpcService)
5051
require.NoError(t, err)
52+
fakeAAPI := NewFakeAgentAPI(t, logger)
53+
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
54+
require.NoError(t, err)
5155
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
5256
Log: func(err error) {
5357
if xerrors.Is(err, io.EOF) {
@@ -64,22 +68,23 @@ func NewClient(t testing.TB,
6468
statsChan: statsChan,
6569
coordinator: coordinator,
6670
server: server,
71+
fakeAgentAPI: fakeAAPI,
6772
derpMapUpdates: derpMapUpdates,
6873
}
6974
}
7075

7176
type Client struct {
72-
t testing.TB
73-
logger slog.Logger
74-
agentID uuid.UUID
75-
manifest agentsdk.Manifest
76-
metadata map[string]agentsdk.Metadata
77-
statsChan chan *agentsdk.Stats
78-
coordinator tailnet.Coordinator
79-
server *drpcserver.Server
80-
LastWorkspaceAgent func()
81-
PatchWorkspaceLogs func() error
82-
GetServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
77+
t testing.TB
78+
logger slog.Logger
79+
agentID uuid.UUID
80+
manifest agentsdk.Manifest
81+
metadata map[string]agentsdk.Metadata
82+
statsChan chan *agentsdk.Stats
83+
coordinator tailnet.Coordinator
84+
server *drpcserver.Server
85+
fakeAgentAPI *FakeAgentAPI
86+
LastWorkspaceAgent func()
87+
PatchWorkspaceLogs func() error
8388

8489
mu sync.Mutex // Protects following.
8590
lifecycleStates []codersdk.WorkspaceAgentLifecycle
@@ -221,20 +226,7 @@ func (c *Client) PatchLogs(ctx context.Context, logs agentsdk.PatchLogs) error {
221226
}
222227

223228
func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) {
224-
c.mu.Lock()
225-
defer c.mu.Unlock()
226-
227-
c.GetServiceBannerFunc = f
228-
}
229-
230-
func (c *Client) GetServiceBanner(ctx context.Context) (codersdk.ServiceBannerConfig, error) {
231-
c.mu.Lock()
232-
defer c.mu.Unlock()
233-
c.logger.Debug(ctx, "get service banner")
234-
if c.GetServiceBannerFunc != nil {
235-
return c.GetServiceBannerFunc()
236-
}
237-
return codersdk.ServiceBannerConfig{}, nil
229+
c.fakeAgentAPI.SetServiceBannerFunc(f)
238230
}
239231

240232
func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error {
@@ -254,3 +246,73 @@ type closeFunc func() error
254246
func (c closeFunc) Close() error {
255247
return c()
256248
}
249+
250+
type FakeAgentAPI struct {
251+
sync.Mutex
252+
t testing.TB
253+
logger slog.Logger
254+
255+
getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
256+
}
257+
258+
func (*FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
259+
// TODO implement me
260+
panic("implement me")
261+
}
262+
263+
func (f *FakeAgentAPI) SetServiceBannerFunc(fn func() (codersdk.ServiceBannerConfig, error)) {
264+
f.Lock()
265+
defer f.Unlock()
266+
f.getServiceBannerFunc = fn
267+
f.logger.Info(context.Background(), "updated ServiceBannerFunc")
268+
}
269+
270+
func (f *FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceBannerRequest) (*agentproto.ServiceBanner, error) {
271+
f.Lock()
272+
defer f.Unlock()
273+
if f.getServiceBannerFunc == nil {
274+
return &agentproto.ServiceBanner{}, nil
275+
}
276+
sb, err := f.getServiceBannerFunc()
277+
if err != nil {
278+
return nil, err
279+
}
280+
return agentproto.ServiceBannerFromSDK(sb), nil
281+
}
282+
283+
func (*FakeAgentAPI) UpdateStats(context.Context, *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {
284+
// TODO implement me
285+
panic("implement me")
286+
}
287+
288+
func (*FakeAgentAPI) UpdateLifecycle(context.Context, *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) {
289+
// TODO implement me
290+
panic("implement me")
291+
}
292+
293+
func (*FakeAgentAPI) BatchUpdateAppHealths(context.Context, *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
294+
// TODO implement me
295+
panic("implement me")
296+
}
297+
298+
func (*FakeAgentAPI) UpdateStartup(context.Context, *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
299+
// TODO implement me
300+
panic("implement me")
301+
}
302+
303+
func (*FakeAgentAPI) BatchUpdateMetadata(context.Context, *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) {
304+
// TODO implement me
305+
panic("implement me")
306+
}
307+
308+
func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLogsRequest) (*agentproto.BatchCreateLogsResponse, error) {
309+
// TODO implement me
310+
panic("implement me")
311+
}
312+
313+
func NewFakeAgentAPI(t testing.TB, logger slog.Logger) *FakeAgentAPI {
314+
return &FakeAgentAPI{
315+
t: t,
316+
logger: logger.Named("FakeAgentAPI"),
317+
}
318+
}

coderd/wsconncache/wsconncache_test.go

+48-35
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import (
2929
"cdr.dev/slog"
3030
"cdr.dev/slog/sloggers/slogtest"
3131
"github.com/coder/coder/v2/agent"
32+
"github.com/coder/coder/v2/agent/agenttest"
33+
agentproto "github.com/coder/coder/v2/agent/proto"
3234
"github.com/coder/coder/v2/coderd/wsconncache"
3335
"github.com/coder/coder/v2/codersdk"
3436
"github.com/coder/coder/v2/codersdk/agentsdk"
@@ -171,13 +173,12 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
171173
_ = coordinator.Close()
172174
})
173175
manifest.AgentID = uuid.New()
174-
aC := &client{
175-
t: t,
176-
agentID: manifest.AgentID,
177-
manifest: manifest,
178-
coordinator: coordinator,
179-
derpMapUpdates: make(chan *tailcfg.DERPMap),
180-
}
176+
aC := newClient(
177+
t,
178+
slogtest.Make(t, nil).Leveled(slog.LevelDebug),
179+
manifest,
180+
coordinator,
181+
)
181182
t.Cleanup(aC.close)
182183
closer := agent.New(agent.Options{
183184
Client: aC,
@@ -239,46 +240,62 @@ type client struct {
239240
coordinator tailnet.Coordinator
240241
closeOnce sync.Once
241242
derpMapUpdates chan *tailcfg.DERPMap
243+
server *drpcserver.Server
244+
fakeAgentAPI *agenttest.FakeAgentAPI
242245
}
243246

244-
func (c *client) close() {
245-
c.closeOnce.Do(func() { close(c.derpMapUpdates) })
246-
}
247-
248-
func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
249-
return c.manifest, nil
250-
}
251-
252-
func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
253-
logger := slogtest.Make(c.t, nil).Leveled(slog.LevelDebug).Named("drpc")
254-
conn, lis := drpcsdk.MemTransportPipe()
255-
c.t.Cleanup(func() {
256-
_ = conn.Close()
257-
_ = lis.Close()
258-
})
247+
func newClient(t *testing.T, logger slog.Logger, manifest agentsdk.Manifest, coordinator tailnet.Coordinator) *client {
248+
logger = logger.Named("drpc")
259249
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
260-
coordPtr.Store(&c.coordinator)
250+
coordPtr.Store(&coordinator)
261251
mux := drpcmux.New()
252+
derpMapUpdates := make(chan *tailcfg.DERPMap)
262253
drpcService := &tailnet.DRPCService{
263254
CoordPtr: &coordPtr,
264255
Logger: logger,
265256
DerpMapUpdateFrequency: time.Microsecond,
266-
DerpMapFn: func() *tailcfg.DERPMap { return <-c.derpMapUpdates },
257+
DerpMapFn: func() *tailcfg.DERPMap { return <-derpMapUpdates },
267258
}
268259
err := proto.DRPCRegisterTailnet(mux, drpcService)
269-
if err != nil {
270-
return nil, xerrors.Errorf("register DRPC service: %w", err)
271-
}
260+
require.NoError(t, err)
261+
fakeAAPI := agenttest.NewFakeAgentAPI(t, logger)
262+
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
263+
require.NoError(t, err)
272264
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
273265
Log: func(err error) {
274-
if xerrors.Is(err, io.EOF) ||
275-
xerrors.Is(err, context.Canceled) ||
276-
xerrors.Is(err, context.DeadlineExceeded) {
266+
if xerrors.Is(err, io.EOF) {
277267
return
278268
}
279269
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
280270
},
281271
})
272+
273+
return &client{
274+
t: t,
275+
agentID: manifest.AgentID,
276+
manifest: manifest,
277+
coordinator: coordinator,
278+
derpMapUpdates: derpMapUpdates,
279+
server: server,
280+
fakeAgentAPI: fakeAAPI,
281+
}
282+
}
283+
284+
func (c *client) close() {
285+
c.closeOnce.Do(func() { close(c.derpMapUpdates) })
286+
}
287+
288+
func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
289+
return c.manifest, nil
290+
}
291+
292+
func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
293+
conn, lis := drpcsdk.MemTransportPipe()
294+
c.t.Cleanup(func() {
295+
_ = conn.Close()
296+
_ = lis.Close()
297+
})
298+
282299
serveCtx, cancel := context.WithCancel(context.Background())
283300
c.t.Cleanup(cancel)
284301
auth := tailnet.AgentTunnelAuth{}
@@ -289,7 +306,7 @@ func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
289306
}
290307
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
291308
go func() {
292-
server.Serve(serveCtx, lis)
309+
c.server.Serve(serveCtx, lis)
293310
}()
294311
return conn, nil
295312
}
@@ -317,7 +334,3 @@ func (*client) PostStartup(_ context.Context, _ agentsdk.PostStartupRequest) err
317334
func (*client) PatchLogs(_ context.Context, _ agentsdk.PatchLogs) error {
318335
return nil
319336
}
320-
321-
func (*client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) {
322-
return codersdk.ServiceBannerConfig{}, nil
323-
}

0 commit comments

Comments
 (0)