Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
expose SetupComplete() method on agent.agent
  • Loading branch information
dwahler committed Jul 14, 2022
commit 1e7552609d7709f332452fc6b9a23aaee429ea65
32 changes: 26 additions & 6 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ const (
ProtocolDial = "dial"
)

type Agent interface {
io.Closer
SetupComplete() bool
}

type Options struct {
EnableWireguard bool
UploadWireguardKeys UploadWireguardKeys
Expand Down Expand Up @@ -72,7 +77,7 @@ type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker
type UploadWireguardKeys func(ctx context.Context, keys WireguardPublicKeys) error
type ListenWireguardPeers func(ctx context.Context, logger slog.Logger) (<-chan peerwg.Handshake, func(), error)

func New(dialer Dialer, options *Options) io.Closer {
func New(dialer Dialer, options *Options) Agent {
if options == nil {
options = &Options{}
}
Expand Down Expand Up @@ -109,8 +114,10 @@ type agent struct {

envVars map[string]string
// metadata is atomic because values can change after reconnection.
metadata atomic.Value
startupScript atomic.Bool
metadata atomic.Value
// tracks whether or not we have started/completed initial setup, including any startup script
setupStarted atomic.Bool
setupComplete atomic.Bool
sshServer *ssh.Server

enableWireguard bool
Expand Down Expand Up @@ -147,15 +154,16 @@ func (a *agent) run(ctx context.Context) {
}
a.metadata.Store(metadata)

if a.startupScript.CAS(false, true) {
if a.setupStarted.CAS(false, true) {
// The startup script has not ran yet!
go func() {
err := a.runStartupScript(ctx, metadata.StartupScript)
defer a.setupComplete.Store(true)
err := a.performInitialSetup(ctx, &metadata)
if errors.Is(err, context.Canceled) {
return
}
if err != nil {
a.logger.Warn(ctx, "agent script failed", slog.Error(err))
a.logger.Warn(ctx, "initial setup failed", slog.Error(err))
}
}()
}
Expand Down Expand Up @@ -184,6 +192,14 @@ func (a *agent) run(ctx context.Context) {
}
}

func (a *agent) performInitialSetup(ctx context.Context, metadata *Metadata) error {
err := a.runStartupScript(ctx, metadata.StartupScript)
if err != nil {
return xerrors.Errorf("agent script failed: %w", err)
}
return nil
}

func (a *agent) runStartupScript(ctx context.Context, script string) error {
if script == "" {
return nil
Expand Down Expand Up @@ -755,6 +771,10 @@ func (a *agent) Close() error {
return nil
}

func (a *agent) SetupComplete() bool {
return a.setupComplete.Load()
}

type reconnectingPTY struct {
activeConnsMutex sync.Mutex
activeConns map[string]net.Conn
Expand Down
34 changes: 15 additions & 19 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,23 +212,7 @@ func TestAgent(t *testing.T) {
StartupScript: fmt.Sprintf("echo %s > %s", content, tempPath),
}, 0)

var gotContent string
require.Eventually(t, func() bool {
content, err := os.ReadFile(tempPath)
if err != nil {
return false
}
if len(content) == 0 {
return false
}
if runtime.GOOS == "windows" {
// Windows uses UTF16! 🪟🪟🪟
content, _, err = transform.Bytes(unicode.UTF16(unicode.LittleEndian, unicode.UseBOM).NewDecoder(), content)
require.NoError(t, err)
}
gotContent = string(content)
return true
}, 15*time.Second, 100*time.Millisecond)
gotContent := readFileContents(t, tempPath)
require.Equal(t, content, strings.TrimSpace(gotContent))
})

Expand Down Expand Up @@ -436,7 +420,7 @@ func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session {

func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn {
client, server := provisionersdk.TransportPipe()
closer := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) {
a := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) {
listener, err := peerbroker.Listen(server, nil)
return metadata, listener, err
}, &agent.Options{
Expand All @@ -446,7 +430,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
t.Cleanup(func() {
_ = client.Close()
_ = server.Close()
_ = closer.Close()
_ = a.Close()
})
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
stream, err := api.NegotiateConnection(context.Background())
Expand All @@ -458,6 +442,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
t.Cleanup(func() {
_ = conn.Close()
})
require.Eventually(t, a.SetupComplete, 10*time.Second, 100*time.Millisecond)

return &agent.Conn{
Negotiator: api,
Expand Down Expand Up @@ -495,3 +480,14 @@ func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
assert.NoError(t, err, "write payload")
assert.Equal(t, len(payload), n, "payload length does not match")
}

func readFileContents(t *testing.T, path string) string {
content, err := os.ReadFile(path)
require.NoError(t, err)
if runtime.GOOS == "windows" {
// Windows uses UTF16! 🪟🪟🪟
content, _, err = transform.Bytes(unicode.UTF16(unicode.LittleEndian, unicode.UseBOM).NewDecoder(), content)
require.NoError(t, err)
}
return string(content)
}