diff --git a/provisionersdk/agent_test.go b/provisionersdk/agent_test.go index 5dbdc41a89cd2..3be01e20dce6f 100644 --- a/provisionersdk/agent_test.go +++ b/provisionersdk/agent_test.go @@ -8,7 +8,6 @@ package provisionersdk_test import ( "bytes" - "context" "errors" "fmt" "net/http" @@ -47,12 +46,10 @@ func TestAgentScript(t *testing.T) { t.Run("Valid", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) script := serveScript(t, bashEcho) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - t.Cleanup(cancel) - - var output bytes.Buffer + var output safeBuffer // This is intentionally ran in single quotes to mimic how a customer may // embed our script. Our scripts should not include any single quotes. // nolint:gosec @@ -84,12 +81,10 @@ func TestAgentScript(t *testing.T) { t.Run("Invalid", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) script := serveScript(t, unexpectedEcho) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - t.Cleanup(cancel) - - var output bytes.Buffer + var output safeBuffer // This is intentionally ran in single quotes to mimic how a customer may // embed our script. Our scripts should not include any single quotes. // nolint:gosec @@ -159,3 +154,33 @@ func serveScript(t *testing.T, in string) string { script = strings.ReplaceAll(script, "${AUTH_TYPE}", "token") return script } + +// safeBuffer is a concurrency-safe bytes.Buffer +type safeBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (sb *safeBuffer) Write(p []byte) (n int, err error) { + sb.mu.Lock() + defer sb.mu.Unlock() + return sb.buf.Write(p) +} + +func (sb *safeBuffer) Read(p []byte) (n int, err error) { + sb.mu.Lock() + defer sb.mu.Unlock() + return sb.buf.Read(p) +} + +func (sb *safeBuffer) Bytes() []byte { + sb.mu.Lock() + defer sb.mu.Unlock() + return sb.buf.Bytes() +} + +func (sb *safeBuffer) String() string { + sb.mu.Lock() + defer sb.mu.Unlock() + return sb.buf.String() +}