Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c78665a

Browse files
committed
add agentscripts test for execute option
1 parent 9aca1c8 commit c78665a

File tree

1 file changed

+150
-0
lines changed

1 file changed

+150
-0
lines changed

agent/agentscripts/agentscripts_test.go

+150
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"path/filepath"
66
"runtime"
7+
"slices"
8+
"sync"
79
"testing"
810
"time"
911

@@ -151,6 +153,154 @@ func TestCronClose(t *testing.T) {
151153
require.NoError(t, runner.Close(), "close runner")
152154
}
153155

156+
func TestExecuteOptions(t *testing.T) {
157+
t.Parallel()
158+
159+
startScript := codersdk.WorkspaceAgentScript{
160+
ID: uuid.New(),
161+
LogSourceID: uuid.New(),
162+
Script: "echo start",
163+
RunOnStart: true,
164+
}
165+
stopScript := codersdk.WorkspaceAgentScript{
166+
ID: uuid.New(),
167+
LogSourceID: uuid.New(),
168+
Script: "echo stop",
169+
RunOnStop: true,
170+
}
171+
postStartScript := codersdk.WorkspaceAgentScript{
172+
ID: uuid.New(),
173+
LogSourceID: uuid.New(),
174+
Script: "echo poststart",
175+
}
176+
regularScript := codersdk.WorkspaceAgentScript{
177+
ID: uuid.New(),
178+
LogSourceID: uuid.New(),
179+
Script: "echo regular",
180+
}
181+
182+
scripts := []codersdk.WorkspaceAgentScript{
183+
startScript,
184+
stopScript,
185+
regularScript,
186+
}
187+
allScripts := append(slices.Clone(scripts), postStartScript)
188+
189+
scriptByID := func(id uuid.UUID) codersdk.WorkspaceAgentScript {
190+
for _, script := range allScripts {
191+
if script.ID == id {
192+
return script
193+
}
194+
}
195+
return codersdk.WorkspaceAgentScript{}
196+
}
197+
198+
wantOutput := map[uuid.UUID]string{
199+
startScript.ID: "start",
200+
stopScript.ID: "stop",
201+
postStartScript.ID: "poststart",
202+
regularScript.ID: "regular",
203+
}
204+
205+
testCases := []struct {
206+
name string
207+
option agentscripts.ExecuteOption
208+
wantRun []uuid.UUID
209+
}{
210+
{
211+
name: "ExecuteAllScripts",
212+
option: agentscripts.ExecuteAllScripts,
213+
wantRun: []uuid.UUID{startScript.ID, stopScript.ID, regularScript.ID, postStartScript.ID},
214+
},
215+
{
216+
name: "ExecuteStartScripts",
217+
option: agentscripts.ExecuteStartScripts,
218+
wantRun: []uuid.UUID{startScript.ID},
219+
},
220+
{
221+
name: "ExecutePostStartScripts",
222+
option: agentscripts.ExecutePostStartScripts,
223+
wantRun: []uuid.UUID{postStartScript.ID},
224+
},
225+
{
226+
name: "ExecuteStopScripts",
227+
option: agentscripts.ExecuteStopScripts,
228+
wantRun: []uuid.UUID{stopScript.ID},
229+
},
230+
}
231+
232+
for _, tc := range testCases {
233+
t.Run(tc.name, func(t *testing.T) {
234+
t.Parallel()
235+
236+
ctx := testutil.Context(t, testutil.WaitMedium)
237+
executedScripts := make(map[uuid.UUID]bool)
238+
fLogger := &filterTestLogger{
239+
tb: t,
240+
executedScripts: executedScripts,
241+
wantOutput: wantOutput,
242+
}
243+
244+
runner := setup(t, func(_ uuid.UUID) agentscripts.ScriptLogger {
245+
return fLogger
246+
})
247+
defer runner.Close()
248+
249+
aAPI := agenttest.NewFakeAgentAPI(t, testutil.Logger(t), nil, nil)
250+
err := runner.Init(
251+
scripts,
252+
aAPI.ScriptCompleted,
253+
agentscripts.WithPostStartScripts(postStartScript),
254+
)
255+
require.NoError(t, err)
256+
257+
err = runner.Execute(ctx, tc.option)
258+
require.NoError(t, err)
259+
260+
gotRun := map[uuid.UUID]bool{}
261+
for _, id := range tc.wantRun {
262+
gotRun[id] = true
263+
require.True(t, executedScripts[id],
264+
"script %s should have run when using filter %s", scriptByID(id).Script, tc.name)
265+
}
266+
267+
for _, script := range allScripts {
268+
if _, ok := gotRun[script.ID]; ok {
269+
continue
270+
}
271+
require.False(t, executedScripts[script.ID],
272+
"script %s should not have run when using filter %s", script.Script, tc.name)
273+
}
274+
})
275+
}
276+
}
277+
278+
type filterTestLogger struct {
279+
tb testing.TB
280+
executedScripts map[uuid.UUID]bool
281+
wantOutput map[uuid.UUID]string
282+
mu sync.Mutex
283+
}
284+
285+
func (l *filterTestLogger) Send(ctx context.Context, logs ...agentsdk.Log) error {
286+
l.mu.Lock()
287+
defer l.mu.Unlock()
288+
for _, log := range logs {
289+
l.tb.Log(log.Output)
290+
for id, output := range l.wantOutput {
291+
if log.Output == output {
292+
l.executedScripts[id] = true
293+
break
294+
}
295+
}
296+
}
297+
return nil
298+
}
299+
300+
func (l *filterTestLogger) Flush(context.Context) error {
301+
return nil
302+
}
303+
154304
func setup(t *testing.T, getScriptLogger func(logSourceID uuid.UUID) agentscripts.ScriptLogger) *agentscripts.Runner {
155305
t.Helper()
156306
if getScriptLogger == nil {

0 commit comments

Comments
 (0)