|
4 | 4 | "context"
|
5 | 5 | "path/filepath"
|
6 | 6 | "runtime"
|
| 7 | + "slices" |
| 8 | + "sync" |
7 | 9 | "testing"
|
8 | 10 | "time"
|
9 | 11 |
|
@@ -151,6 +153,154 @@ func TestCronClose(t *testing.T) {
|
151 | 153 | require.NoError(t, runner.Close(), "close runner")
|
152 | 154 | }
|
153 | 155 |
|
| 156 | +func TestExecuteOptions(t *testing.T) { |
| 157 | + t.Parallel() |
| 158 | + |
| 159 | + startScript := codersdk.WorkspaceAgentScript{ |
| 160 | + ID: uuid.New(), |
| 161 | + LogSourceID: uuid.New(), |
| 162 | + Script: "echo start", |
| 163 | + RunOnStart: true, |
| 164 | + } |
| 165 | + stopScript := codersdk.WorkspaceAgentScript{ |
| 166 | + ID: uuid.New(), |
| 167 | + LogSourceID: uuid.New(), |
| 168 | + Script: "echo stop", |
| 169 | + RunOnStop: true, |
| 170 | + } |
| 171 | + postStartScript := codersdk.WorkspaceAgentScript{ |
| 172 | + ID: uuid.New(), |
| 173 | + LogSourceID: uuid.New(), |
| 174 | + Script: "echo poststart", |
| 175 | + } |
| 176 | + regularScript := codersdk.WorkspaceAgentScript{ |
| 177 | + ID: uuid.New(), |
| 178 | + LogSourceID: uuid.New(), |
| 179 | + Script: "echo regular", |
| 180 | + } |
| 181 | + |
| 182 | + scripts := []codersdk.WorkspaceAgentScript{ |
| 183 | + startScript, |
| 184 | + stopScript, |
| 185 | + regularScript, |
| 186 | + } |
| 187 | + allScripts := append(slices.Clone(scripts), postStartScript) |
| 188 | + |
| 189 | + scriptByID := func(id uuid.UUID) codersdk.WorkspaceAgentScript { |
| 190 | + for _, script := range allScripts { |
| 191 | + if script.ID == id { |
| 192 | + return script |
| 193 | + } |
| 194 | + } |
| 195 | + return codersdk.WorkspaceAgentScript{} |
| 196 | + } |
| 197 | + |
| 198 | + wantOutput := map[uuid.UUID]string{ |
| 199 | + startScript.ID: "start", |
| 200 | + stopScript.ID: "stop", |
| 201 | + postStartScript.ID: "poststart", |
| 202 | + regularScript.ID: "regular", |
| 203 | + } |
| 204 | + |
| 205 | + testCases := []struct { |
| 206 | + name string |
| 207 | + option agentscripts.ExecuteOption |
| 208 | + wantRun []uuid.UUID |
| 209 | + }{ |
| 210 | + { |
| 211 | + name: "ExecuteAllScripts", |
| 212 | + option: agentscripts.ExecuteAllScripts, |
| 213 | + wantRun: []uuid.UUID{startScript.ID, stopScript.ID, regularScript.ID, postStartScript.ID}, |
| 214 | + }, |
| 215 | + { |
| 216 | + name: "ExecuteStartScripts", |
| 217 | + option: agentscripts.ExecuteStartScripts, |
| 218 | + wantRun: []uuid.UUID{startScript.ID}, |
| 219 | + }, |
| 220 | + { |
| 221 | + name: "ExecutePostStartScripts", |
| 222 | + option: agentscripts.ExecutePostStartScripts, |
| 223 | + wantRun: []uuid.UUID{postStartScript.ID}, |
| 224 | + }, |
| 225 | + { |
| 226 | + name: "ExecuteStopScripts", |
| 227 | + option: agentscripts.ExecuteStopScripts, |
| 228 | + wantRun: []uuid.UUID{stopScript.ID}, |
| 229 | + }, |
| 230 | + } |
| 231 | + |
| 232 | + for _, tc := range testCases { |
| 233 | + t.Run(tc.name, func(t *testing.T) { |
| 234 | + t.Parallel() |
| 235 | + |
| 236 | + ctx := testutil.Context(t, testutil.WaitMedium) |
| 237 | + executedScripts := make(map[uuid.UUID]bool) |
| 238 | + fLogger := &filterTestLogger{ |
| 239 | + tb: t, |
| 240 | + executedScripts: executedScripts, |
| 241 | + wantOutput: wantOutput, |
| 242 | + } |
| 243 | + |
| 244 | + runner := setup(t, func(_ uuid.UUID) agentscripts.ScriptLogger { |
| 245 | + return fLogger |
| 246 | + }) |
| 247 | + defer runner.Close() |
| 248 | + |
| 249 | + aAPI := agenttest.NewFakeAgentAPI(t, testutil.Logger(t), nil, nil) |
| 250 | + err := runner.Init( |
| 251 | + scripts, |
| 252 | + aAPI.ScriptCompleted, |
| 253 | + agentscripts.WithPostStartScripts(postStartScript), |
| 254 | + ) |
| 255 | + require.NoError(t, err) |
| 256 | + |
| 257 | + err = runner.Execute(ctx, tc.option) |
| 258 | + require.NoError(t, err) |
| 259 | + |
| 260 | + gotRun := map[uuid.UUID]bool{} |
| 261 | + for _, id := range tc.wantRun { |
| 262 | + gotRun[id] = true |
| 263 | + require.True(t, executedScripts[id], |
| 264 | + "script %s should have run when using filter %s", scriptByID(id).Script, tc.name) |
| 265 | + } |
| 266 | + |
| 267 | + for _, script := range allScripts { |
| 268 | + if _, ok := gotRun[script.ID]; ok { |
| 269 | + continue |
| 270 | + } |
| 271 | + require.False(t, executedScripts[script.ID], |
| 272 | + "script %s should not have run when using filter %s", script.Script, tc.name) |
| 273 | + } |
| 274 | + }) |
| 275 | + } |
| 276 | +} |
| 277 | + |
| 278 | +type filterTestLogger struct { |
| 279 | + tb testing.TB |
| 280 | + executedScripts map[uuid.UUID]bool |
| 281 | + wantOutput map[uuid.UUID]string |
| 282 | + mu sync.Mutex |
| 283 | +} |
| 284 | + |
| 285 | +func (l *filterTestLogger) Send(ctx context.Context, logs ...agentsdk.Log) error { |
| 286 | + l.mu.Lock() |
| 287 | + defer l.mu.Unlock() |
| 288 | + for _, log := range logs { |
| 289 | + l.tb.Log(log.Output) |
| 290 | + for id, output := range l.wantOutput { |
| 291 | + if log.Output == output { |
| 292 | + l.executedScripts[id] = true |
| 293 | + break |
| 294 | + } |
| 295 | + } |
| 296 | + } |
| 297 | + return nil |
| 298 | +} |
| 299 | + |
| 300 | +func (l *filterTestLogger) Flush(context.Context) error { |
| 301 | + return nil |
| 302 | +} |
| 303 | + |
154 | 304 | func setup(t *testing.T, getScriptLogger func(logSourceID uuid.UUID) agentscripts.ScriptLogger) *agentscripts.Runner {
|
155 | 305 | t.Helper()
|
156 | 306 | if getScriptLogger == nil {
|
|
0 commit comments