diff --git a/internal/oci/container.go b/internal/oci/container.go index 4e85c0013cf..87273a6e736 100644 --- a/internal/oci/container.go +++ b/internal/oci/container.go @@ -71,10 +71,9 @@ type Container struct { created bool spoofed bool stopping bool - stopTimeoutChan chan time.Duration - stoppedChan chan struct{} - stopStoppingChan chan struct{} stopLock sync.Mutex + stopTimeoutChan chan int64 + stopWatchers []chan struct{} pidns nsmgr.Namespace restore bool restoreArchive string @@ -136,21 +135,20 @@ func NewContainer(id, name, bundlePath, logPath string, labels, crioAnnotations, }, ImageRef: imageRef, }, - name: name, - bundlePath: bundlePath, - logPath: logPath, - terminal: terminal, - stdin: stdin, - stdinOnce: stdinOnce, - runtimeHandler: runtimeHandler, - crioAnnotations: crioAnnotations, - imageName: imageName, - dir: dir, - state: state, - stopSignal: stopSignal, - stopTimeoutChan: make(chan time.Duration, 1), - stoppedChan: make(chan struct{}, 1), - stopStoppingChan: make(chan struct{}, 1), + name: name, + bundlePath: bundlePath, + logPath: logPath, + terminal: terminal, + stdin: stdin, + stdinOnce: stdinOnce, + runtimeHandler: runtimeHandler, + crioAnnotations: crioAnnotations, + imageName: imageName, + dir: dir, + state: state, + stopSignal: stopSignal, + stopTimeoutChan: make(chan int64, 10), + stopWatchers: []chan struct{}{}, } return c, nil } @@ -499,9 +497,9 @@ func (c *Container) exitFilePath() string { return filepath.Join(c.dir, "exit") } -// IsAlive is a function that checks if a container's init PID exists. +// Living is a function that checks if a container's init PID exists. // It is used to check a container state when we don't want a `$runtime state` call -func (c *Container) IsAlive() error { +func (c *Container) Living() error { if _, err := c.pid(); err != nil { return fmt.Errorf("checking if PID of %s is running failed: %w", c.ID(), err) } @@ -603,7 +601,7 @@ func GetPidStartTimeFromFile(file string) (string, error) { // a container is not stoppable if it's paused or stopped // if it's paused, that's an error, and is reported as such func (c *Container) ShouldBeStopped() error { - switch c.state.Status { + switch c.State().Status { case ContainerStateStopped: // no-op return ErrContainerStopped case ContainerStatePaused: @@ -621,41 +619,34 @@ func (c *Container) Spoofed() bool { } // SetAsStopping marks a container as being stopped. -// If a stop is currently happening, it also sends the new timeout -// along the stopTimeoutChan, allowing the in-progress stop -// to stop faster, or ignore the new stop timeout. -// In this case, it also returns true, signifying the caller doesn't have to -// Do any stop related cleanup, as the original caller (alreadyStopping=false) -// will do said cleanup. -func (c *Container) SetAsStopping(timeout int64) (alreadyStopping bool) { - // First, need to check if the container is already stopping +// Returns true if the container was not set as stopping before, and false otherwise (i.e. on subsequent calls)." +func (c *Container) SetAsStopping() (setToStopping bool) { c.stopLock.Lock() defer c.stopLock.Unlock() - if c.stopping { - // If so, we shouldn't wait forever on the opLock. - // This can cause issues where the container stop gets DOSed by a very long - // timeout, followed a shorter one coming in. - // Instead, interrupt the other stop with this new one. - select { - case c.stopTimeoutChan <- time.Duration(timeout) * time.Second: - case <-c.stoppedChan: // This case is to avoid waiting forever once another routine has finished. - case <-c.stopStoppingChan: // This case is to avoid deadlocking with SetAsNotStopping. - } + if !c.stopping { + c.stopping = true return true } - // Regardless, set the container as actively stopping. - c.stopping = true - // And reset the stopStoppingChan - c.stopStoppingChan = make(chan struct{}, 1) return false } -// SetAsNotStopping unsets the stopping field indicating to new callers that the container -// is no longer actively stopping. -func (c *Container) SetAsNotStopping() { +func (c *Container) WaitOnStopTimeout(ctx context.Context, timeout int64) { c.stopLock.Lock() - c.stopping = false + if !c.stopping { + c.stopLock.Unlock() + return + } + + c.stopTimeoutChan <- timeout + + watcher := make(chan struct{}, 1) + c.stopWatchers = append(c.stopWatchers, watcher) c.stopLock.Unlock() + + select { + case <-ctx.Done(): + case <-watcher: + } } func (c *Container) AddManagedPIDNamespace(ns nsmgr.Namespace) { diff --git a/internal/oci/container_test.go b/internal/oci/container_test.go index 438e392f1e3..b9f6796bdf8 100644 --- a/internal/oci/container_test.go +++ b/internal/oci/container_test.go @@ -461,14 +461,14 @@ var _ = t.Describe("Container", func() { Expect(err).To(BeNil()) }) }) - t.Describe("IsAlive", func() { + t.Describe("Living", func() { It("should be false if pid uninitialized", func() { // Given state := &oci.ContainerState{} state.Pid = 0 sut.SetState(state) // When - err := sut.IsAlive() + err := sut.Living() // Then Expect(err).NotTo(BeNil()) @@ -480,7 +480,7 @@ var _ = t.Describe("Container", func() { Expect(state.SetInitPid(state.Pid)).To(BeNil()) sut.SetState(state) // When - err := sut.IsAlive() + err := sut.Living() // Then Expect(err).To(BeNil()) @@ -493,7 +493,7 @@ var _ = t.Describe("Container", func() { Expect(state.SetInitPid(state.Pid)).NotTo(BeNil()) sut.SetState(state) // When - err := sut.IsAlive() + err := sut.Living() // Then Expect(err).NotTo(BeNil()) diff --git a/internal/oci/container_test_inject.go b/internal/oci/container_test_inject.go index 215dab4e511..db0fdaf4bd5 100644 --- a/internal/oci/container_test_inject.go +++ b/internal/oci/container_test_inject.go @@ -6,6 +6,10 @@ package oci +import ( + "github.com/cri-o/cri-o/pkg/config" +) + // SetState sets the container state func (c *Container) SetState(state *ContainerState) { c.state = state @@ -24,3 +28,17 @@ func (c *Container) SetStateAndSpoofPid(state *ContainerState) { } c.state = state } + +type RuntimeOCI struct { + *runtimeOCI +} + +func NewRuntimeOCI(r *Runtime, handler *config.RuntimeHandler) RuntimeOCI { + return RuntimeOCI{ + runtimeOCI: &runtimeOCI{ + Runtime: r, + root: handler.RuntimeRoot, + handler: handler, + }, + } +} diff --git a/internal/oci/runtime_oci.go b/internal/oci/runtime_oci.go index 68fac040123..0ea0b7aa91f 100644 --- a/internal/oci/runtime_oci.go +++ b/internal/oci/runtime_oci.go @@ -784,148 +784,114 @@ func (r *runtimeOCI) UpdateContainer(ctx context.Context, c *Container, res *rsp return nil } -func WaitContainerStop(ctx context.Context, c *Container, timeout time.Duration, ignoreKill bool) error { +// StopContainer stops a container. Timeout is given in seconds. +func (r *runtimeOCI) StopContainer(ctx context.Context, c *Container, timeout int64) (retErr error) { + ctx, span := log.StartSpan(ctx) + defer span.End() + + if c.Spoofed() { + c.state.Status = ContainerStateStopped + c.state.Finished = time.Now() + return nil + } + + if err := c.ShouldBeStopped(); err != nil { + if errors.Is(err, ErrContainerStopped) { + err = nil + } + return err + } + + // The initial container process either doesn't exist, or isn't ours. + if err := c.Living(); err != nil { + c.state.Finished = time.Now() + return nil + } + + if c.SetAsStopping() { + go r.StopLoopForContainer(c) + } + + c.WaitOnStopTimeout(ctx, timeout) + return nil +} + +func (r *runtimeOCI) StopLoopForContainer(c *Container) { + ctx := context.Background() ctx, span := log.StartSpan(ctx) defer span.End() + c.opLock.Lock() + + // Begin the actual kill + if _, err := r.runtimeCmd("kill", c.ID(), c.GetStopSignal()); err != nil { + if err := c.Living(); err != nil { + // The initial container process either doesn't exist, or isn't ours. + // Set state accordingly. + c.state.Finished = time.Now() + c.opLock.Unlock() + return + } + } + done := make(chan struct{}) - // we could potentially re-use "done" channel to exit the loop on timeout, - // but we use another channel "chControl" so that we never panic - // attempting to close an already-closed "done" channel. The panic - // would occur in the "default" select case below if we'd closed the - // "done" channel (instead of the "chControl" channel) in the timeout - // select case. - chControl := make(chan struct{}) go func() { for { - select { - case <-chControl: + if err := c.Living(); err != nil { + // The initial container process either doesn't exist, or isn't ours. + if !errors.Is(err, ErrNotFound) { + log.Warnf(ctx, "Failed to find process for container %s: %v", c.ID(), err) + } close(done) return - default: - if err := c.verifyPid(); err != nil { - // The initial container process either doesn't exist, or isn't ours. - if !errors.Is(err, ErrNotFound) { - log.Warnf(ctx, "Failed to find process for container %s: %v", c.ID(), err) - } - close(done) - return - } - // the PID is still active and belongs to the container, continue to wait - time.Sleep(100 * time.Millisecond) } + // the PID is still active and belongs to the container, continue to wait + time.Sleep(100 * time.Millisecond) } }() + // Operate in terms of targetTime, so that we can pause in the middle of the operation // to catch a new timeout (and possibly ignore that new timeout if it's not correct to // take a new one). - targetTime := time.Now().Add(timeout) - killed := false - for !killed { + targetTime := time.Unix(1<<50-1, 0) + for finished := false; !finished; { select { - case <-done: - return nil - case <-ctx.Done(): - close(chControl) - return ctx.Err() - case <-time.After(time.Until(targetTime)): - close(chControl) - if ignoreKill { - return fmt.Errorf("timeout reached after %.0f seconds waiting for container process to exit", - timeout.Seconds()) - } - pid, err := c.pid() - if err != nil { - return err - } - if err := Kill(pid); err != nil { - return fmt.Errorf("failed to kill process: %w", err) - } - killed = true case newTimeout := <-c.stopTimeoutChan: // If a new timeout comes in, // interrupt the old one, and start a new one - newTargetTime := time.Now().Add(newTimeout) + newTargetTime := time.Now().Add(time.Duration(newTimeout) * time.Second) // but only if it's earlier - if newTargetTime.After(targetTime) { - continue + if newTargetTime.Before(targetTime) { + targetTime = newTargetTime } - targetTime = newTargetTime - timeout = newTimeout - } - } - c.state.Finished = time.Now() - // Successfully stopped! This is to prevent other routines from - // racing with this one and waiting forever. - // Close only the dedicated channel. If we close stopTimeoutChan, - // any other waiting goroutine will panic, not gracefully exit. - close(c.stoppedChan) - return nil -} - -// StopContainer stops a container. Timeout is given in seconds. -func (r *runtimeOCI) StopContainer(ctx context.Context, c *Container, timeout int64) (retErr error) { - ctx, span := log.StartSpan(ctx) - defer span.End() - if c.SetAsStopping(timeout) { - return nil - } - defer func() { - // Failed to stop, set stopping to false. - // Otherwise, we won't actually - // attempt to stop when a new request comes in, - // even though we're not actively stopping anymore. - // Also, close the stopStoppingChan to tell - // routines waiting to change the stop timeout to give up. - close(c.stopStoppingChan) - c.SetAsNotStopping() - }() - - c.opLock.Lock() - defer c.opLock.Unlock() - - if err := c.ShouldBeStopped(); err != nil { - return err - } - - if c.Spoofed() { - c.state.Status = ContainerStateStopped - c.state.Finished = time.Now() - return nil - } - - // The initial container process either doesn't exist, or isn't ours. - if err := c.verifyPid(); err != nil { - c.state.Finished = time.Now() - return nil - } + case <-time.After(time.Until(targetTime)): + log.Warnf(ctx, "Stopping container %v with stop signal timed out. Killing", c.ID()) + if _, err := r.runtimeCmd("kill", c.ID(), "KILL"); err != nil { + log.Errorf(ctx, "Killing container %v failed: %v", c.ID(), err) + } + if err := c.Living(); err != nil { + finished = true + break + } - if timeout > 0 { - if _, err := r.runtimeCmd("kill", c.ID(), c.GetStopSignal()); err != nil { - checkProcessGone(c) - } - err := WaitContainerStop(ctx, c, time.Duration(timeout)*time.Second, true) - if err == nil { - return nil + case <-done: + finished = true + break } - log.Warnf(ctx, "Stopping container %v with stop signal timed out: %v", c.ID(), err) } - if _, err := r.runtimeCmd("kill", c.ID(), "KILL"); err != nil { - checkProcessGone(c) - } - - return WaitContainerStop(ctx, c, killContainerTimeout, false) -} + c.state.Finished = time.Now() + c.opLock.Unlock() -func checkProcessGone(c *Container) { - if err := c.verifyPid(); err != nil { - // The initial container process either doesn't exist, or isn't ours. - // Set state accordingly. - c.state.Finished = time.Now() + c.stopLock.Lock() + for _, watcher := range c.stopWatchers { + close(watcher) } + c.stopping = false + close(c.stopTimeoutChan) + c.stopLock.Unlock() } // DeleteContainer deletes a container. diff --git a/internal/oci/runtime_oci_test.go b/internal/oci/runtime_oci_test.go index 4c8a80ff04a..b7dd1023776 100644 --- a/internal/oci/runtime_oci_test.go +++ b/internal/oci/runtime_oci_test.go @@ -2,31 +2,35 @@ package oci_test import ( "context" - "fmt" "math/rand" "os" "os/exec" "time" "github.com/cri-o/cri-o/internal/oci" + libconfig "github.com/cri-o/cri-o/pkg/config" + runnerMock "github.com/cri-o/cri-o/test/mocks/cmdrunner" + "github.com/cri-o/cri-o/utils/cmdrunner" + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) const ( shortTimeout int64 = 1 - mediumTimeout int64 = 3 + mediumTimeout int64 = 5 longTimeout int64 = 15 ) // The actual test suite var _ = t.Describe("Oci", func() { - Context("ContainerStop", func() { + Context("StopContainer", func() { var ( sut *oci.Container sleepProcess *exec.Cmd + runner *runnerMock.MockCommandRunner + runtime oci.RuntimeOCI ) - BeforeEach(func() { sleepProcess = exec.Command("sleep", "100000") Expect(sleepProcess.Start()).To(BeNil()) @@ -38,6 +42,15 @@ var _ = t.Describe("Oci", func() { state.Pid = sleepProcess.Process.Pid Expect(state.SetInitPid(sleepProcess.Process.Pid)).To(BeNil()) sut.SetState(state) + + runner = runnerMock.NewMockCommandRunner(mockCtrl) + cmdrunner.SetMocked(runner) + + cfg, err := libconfig.DefaultConfig() + Expect(err).To(BeNil()) + r, err := oci.New(cfg) + Expect(err).To(BeNil()) + runtime = oci.NewRuntimeOCI(r, &libconfig.RuntimeHandler{}) }) AfterEach(func() { // nolint:errcheck @@ -45,103 +58,125 @@ var _ = t.Describe("Oci", func() { // make sure the entry in the process table is cleaned up // nolint:errcheck sleepProcess.Wait() + cmdrunner.ResetPrependedCmd() }) - tests := []struct { - ignoreKill bool - verifyCorrectlyStopped func(*oci.Container, *exec.Cmd, error) - name string - }{ - { - ignoreKill: true, - verifyCorrectlyStopped: verifyContainerNotStopped, - name: "ignoring kill", - }, - { - ignoreKill: false, - verifyCorrectlyStopped: verifyContainerStopped, - name: "not ignoring kill", - }, - } - for _, test := range tests { - test := test - It("should stop container after timeout if "+test.name, func() { - // Given - sut.SetAsStopping(shortTimeout) - // When - err := oci.WaitContainerStop(context.Background(), sut, inSeconds(shortTimeout), test.ignoreKill) + It("should fail to stop if container paused", func() { + state := &oci.ContainerState{} + state.Status = oci.ContainerStatePaused + sut.SetState(state) - // Then - test.verifyCorrectlyStopped(sut, sleepProcess, err) - }) - It("should interrupt longer stop timeout if "+test.name, func() { - // Given - stoppedChan := make(chan error, 1) - sut.SetAsStopping(longTimeout) - go waitContainerStopAndFailAfterTimeout(context.Background(), stoppedChan, sut, longTimeout, longTimeout, test.ignoreKill) + Expect(sut.ShouldBeStopped()).NotTo(BeNil()) + }) + It("should fail to stop if container stopped", func() { + state := &oci.ContainerState{} + state.Status = oci.ContainerStateStopped + sut.SetState(state) - // When - sut.SetAsStopping(shortTimeout) + Expect(sut.ShouldBeStopped()).To(Equal(oci.ErrContainerStopped)) + }) + It("should return early if runtime command fails and process stopped", func() { + // Given + gomock.InOrder( + runner.EXPECT().Command(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ string, _ ...string) interface{} { + Expect(oci.Kill(sleepProcess.Process.Pid)).To(BeNil()) + waitForKillToComplete(sleepProcess) + return exec.Command("/bin/false") + }, + ), + ) - // Then - test.verifyCorrectlyStopped(sut, sleepProcess, <-stoppedChan) - }) - It("should handle being killed mid-timeout if "+test.name, func() { - // Given - stoppedChan := make(chan error, 1) - sut.SetAsStopping(longTimeout) - go waitContainerStopAndFailAfterTimeout(context.Background(), stoppedChan, sut, longTimeout, mediumTimeout, test.ignoreKill) - - // When - // nolint:errcheck - oci.Kill(sleepProcess.Process.Pid) - waitForKillToComplete(sleepProcess) + // When + sut.SetAsStopping() + runtime.StopLoopForContainer(sut) - // Then - // unconditionally expect the container was stopped - verifyContainerStopped(sut, sleepProcess, <-stoppedChan) - }) - It("should handle context timeout if "+test.name, func() { - // Given - ctx, cancel := context.WithCancel(context.Background()) - stoppedChan := make(chan error, 1) - sut.SetAsStopping(longTimeout) - go waitContainerStopAndFailAfterTimeout(ctx, stoppedChan, sut, longTimeout, mediumTimeout, test.ignoreKill) - - // When - cancel() - - // Then - // unconditionally expect the container was not stopped - verifyContainerNotStopped(sut, sleepProcess, <-stoppedChan) - }) - It("should not update time if chronologically after if "+test.name, func() { - // Given - stoppedChan := make(chan error, 1) - sut.SetAsStopping(mediumTimeout) - go waitContainerStopAndFailAfterTimeout(context.Background(), stoppedChan, sut, mediumTimeout, mediumTimeout, test.ignoreKill) + // Then + Expect(sut.State().Finished).NotTo(BeZero()) + verifyContainerStopped(sut, sleepProcess) + }) + It("should stop container before timeout", func() { + // Given + gomock.InOrder( + runner.EXPECT().Command(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ string, _ ...string) interface{} { + Expect(oci.Kill(sleepProcess.Process.Pid)).To(BeNil()) + waitForKillToComplete(sleepProcess) + return exec.Command("/bin/true") + }, + ), + ) + sut.SetAsStopping() + go runtime.StopLoopForContainer(sut) - // When - sut.SetAsStopping(longTimeout) + // Then + waitOnContainerTimeout(sut, longTimeout, mediumTimeout, sleepProcess) + }) + It("should fall back to KILL after timeout", func() { + // Given + containerIgnoreSignalCmdrunnerMock(sleepProcess, runner) + sut.SetAsStopping() + go runtime.StopLoopForContainer(sut) - // Then - test.verifyCorrectlyStopped(sut, sleepProcess, <-stoppedChan) - }) - It("should handle many updates if "+test.name, func() { - // Given - stoppedChan := make(chan error, 1) - sut.SetAsStopping(longTimeout) - go waitContainerStopAndFailAfterTimeout(context.Background(), stoppedChan, sut, longTimeout, longTimeout, test.ignoreKill) - - // When - for i := 0; i < 5; i++ { - go sut.SetAsStopping(int64(rand.Intn(10))) - } - - // Then - test.verifyCorrectlyStopped(sut, sleepProcess, <-stoppedChan) - }) - } + // Then + waitOnContainerTimeout(sut, shortTimeout, mediumTimeout, sleepProcess) + }) + It("should interrupt longer stop timeout", func() { + // Given + containerIgnoreSignalCmdrunnerMock(sleepProcess, runner) + sut.SetAsStopping() + go runtime.StopLoopForContainer(sut) + go sut.WaitOnStopTimeout(context.Background(), longTimeout) + + // Then + waitOnContainerTimeout(sut, shortTimeout, mediumTimeout, sleepProcess) + }) + + It("should not update time if chronologically after", func() { + // Given + containerIgnoreSignalCmdrunnerMock(sleepProcess, runner) + sut.SetAsStopping() + go runtime.StopLoopForContainer(sut) + + // When + shortStopChan := stopTimeoutWithChannel(context.Background(), sut, shortTimeout) + + // Then + waitOnContainerTimeout(sut, mediumTimeout, longTimeout, sleepProcess) + <-shortStopChan + }) + It("should handle many updates", func() { + // Given + containerIgnoreSignalCmdrunnerMock(sleepProcess, runner) + sut.SetAsStopping() + go runtime.StopLoopForContainer(sut) + // very long timeout + stoppedChan := stopTimeoutWithChannel(context.Background(), sut, longTimeout*10) + + // When + for i := 0; i < 10; i++ { + go sut.WaitOnStopTimeout(context.Background(), int64(rand.Intn(100)+20)) + time.Sleep(time.Second) + } + sut.WaitOnStopTimeout(context.Background(), mediumTimeout) + + // Then + <-stoppedChan + verifyContainerStopped(sut, sleepProcess) + }) + It("should handle context timeout", func() { + // Given + ctx, cancel := context.WithCancel(context.Background()) + stoppedChan := stopTimeoutWithChannel(ctx, sut, shortTimeout) + + // When + cancel() + + // Then + // unconditionally expect the container was not stopped + <-stoppedChan + verifyContainerNotStopped(sut) + }) }) Context("TruncateAndReadFile", func() { tests := []struct { @@ -183,23 +218,44 @@ var _ = t.Describe("Oci", func() { }) }) -func waitContainerStopAndFailAfterTimeout(ctx context.Context, - stoppedChan chan error, - sut *oci.Container, - waitContainerStopTimeout int64, - failAfterTimeout int64, - ignoreKill bool, -) { +func containerIgnoreSignalCmdrunnerMock(sleepProcess *exec.Cmd, runner *runnerMock.MockCommandRunner) { + gomock.InOrder( + runner.EXPECT().Command(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ string, _ ...string) interface{} { + return exec.Command("/bin/true") + }, + ), + runner.EXPECT().Command(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ string, _ ...string) interface{} { + Expect(oci.Kill(sleepProcess.Process.Pid)).To(BeNil()) + waitForKillToComplete(sleepProcess) + return exec.Command("/bin/true") + }, + ), + ) +} + +func waitOnContainerTimeout(sut *oci.Container, stopTimeout, waitTimeout int64, sleepProcess *exec.Cmd) { + stoppedChan := stopTimeoutWithChannel(context.Background(), sut, stopTimeout) + select { - case stoppedChan <- oci.WaitContainerStop(ctx, sut, inSeconds(waitContainerStopTimeout), ignoreKill): - case <-time.After(inSeconds(failAfterTimeout)): - stoppedChan <- fmt.Errorf("%d seconds passed, container kill should have been recognized", failAfterTimeout) + case <-stoppedChan: + case <-time.After(time.Second * time.Duration(waitTimeout)): + Fail("did not timeout quickly enough") } - close(stoppedChan) + verifyContainerStopped(sut, sleepProcess) +} + +func stopTimeoutWithChannel(ctx context.Context, sut *oci.Container, timeout int64) chan struct{} { + stoppedChan := make(chan struct{}, 1) + go func() { + sut.WaitOnStopTimeout(ctx, timeout) + close(stoppedChan) + }() + return stoppedChan } -func verifyContainerStopped(sut *oci.Container, sleepProcess *exec.Cmd, waitError error) { - Expect(waitError).To(BeNil()) +func verifyContainerStopped(sut *oci.Container, sleepProcess *exec.Cmd) { waitForKillToComplete(sleepProcess) pid, err := sut.Pid() Expect(pid).To(Equal(0)) @@ -213,8 +269,7 @@ func waitForKillToComplete(sleepProcess *exec.Cmd) { time.Sleep(inSeconds(shortTimeout)) } -func verifyContainerNotStopped(sut *oci.Container, _ *exec.Cmd, waitError error) { - Expect(waitError).NotTo(BeNil()) +func verifyContainerNotStopped(sut *oci.Container) { pid, err := sut.Pid() Expect(pid).NotTo(Equal(0)) Expect(err).To(BeNil()) diff --git a/internal/oci/runtime_vm.go b/internal/oci/runtime_vm.go index b04e2b6c11b..27b14fe46cb 100644 --- a/internal/oci/runtime_vm.go +++ b/internal/oci/runtime_vm.go @@ -521,14 +521,17 @@ func (r *runtimeVM) StopContainer(ctx context.Context, c *Container, timeout int log.Debugf(ctx, "RuntimeVM.StopContainer() start") defer log.Debugf(ctx, "RuntimeVM.StopContainer() end") - // Lock the container - c.opLock.Lock() - defer c.opLock.Unlock() - if err := c.ShouldBeStopped(); err != nil { + if errors.Is(err, ErrContainerStopped) { + err = nil + } return err } + // Lock the container + c.opLock.Lock() + defer c.opLock.Unlock() + // Cancel the context before returning to ensure goroutines are stopped. ctx, cancel := context.WithCancel(r.ctx) defer cancel() diff --git a/server/container_execsync.go b/server/container_execsync.go index 03a9246a197..0f0093cefba 100644 --- a/server/container_execsync.go +++ b/server/container_execsync.go @@ -19,7 +19,7 @@ func (s *Server) ExecSync(ctx context.Context, req *types.ExecSyncRequest) (*typ return nil, status.Errorf(codes.NotFound, "could not find container %q: %v", req.ContainerId, err) } - if err := c.IsAlive(); err != nil { + if err := c.Living(); err != nil { return nil, status.Errorf(codes.NotFound, "container is not created or running: %v", err) } diff --git a/server/container_remove.go b/server/container_remove.go index 57bda0c7af4..db7070d0bd1 100644 --- a/server/container_remove.go +++ b/server/container_remove.go @@ -53,7 +53,7 @@ func (s *Server) removeContainerInPod(ctx context.Context, sb *sandbox.Sandbox, defer span.End() if !sb.Stopped() { if err := s.stopContainer(ctx, c, int64(10)); err != nil { - return fmt.Errorf("failed to stop container for removal") + return fmt.Errorf("failed to stop container for removal %w", err) } } diff --git a/server/container_remove_test.go b/server/container_remove_test.go index b20181875b0..2b06c88af18 100644 --- a/server/container_remove_test.go +++ b/server/container_remove_test.go @@ -35,6 +35,9 @@ var _ = t.Describe("ContainerRemove", func() { runtimeServerMock.EXPECT().DeleteContainer(gomock.Any(), gomock.Any()). Return(nil), ) + // This allows us to skip stopContainer() which fails because we don't + // spoof the `runtime state` call in `UpdateContainerStatus` + testSandbox.SetStopped(context.Background(), false) // When _, err := sut.RemoveContainer(context.Background(), diff --git a/server/container_stop.go b/server/container_stop.go index 495468afee9..f44e83e02a8 100644 --- a/server/container_stop.go +++ b/server/container_stop.go @@ -58,21 +58,13 @@ func (s *Server) stopContainer(ctx context.Context, ctr *oci.Container, timeout } if err := s.Runtime().StopContainer(ctx, ctr, timeout); err != nil { - // only fatally error if the error is not that the container was already stopped - // we still want to write container state to disk if the container has already - // been stopped - if err != oci.ErrContainerStopped { - return fmt.Errorf("failed to stop container %s: %w", ctr.ID(), err) - } - } else { - // we only do these operations if StopContainer didn't fail (even if the failure - // was the container already being stopped) - if err := s.Runtime().UpdateContainerStatus(ctx, ctr); err != nil { - return fmt.Errorf("failed to update container status %s: %w", ctr.ID(), err) - } - if err := s.StorageRuntimeServer().StopContainer(ctx, ctr.ID()); err != nil { - return fmt.Errorf("failed to unmount container %s: %w", ctr.ID(), err) - } + return fmt.Errorf("failed to stop container %s: %w", ctr.ID(), err) + } + if err := s.Runtime().UpdateContainerStatus(ctx, ctr); err != nil { + return fmt.Errorf("failed to update container status %s: %w", ctr.ID(), err) + } + if err := s.StorageRuntimeServer().StopContainer(ctx, ctr.ID()); err != nil { + return fmt.Errorf("failed to unmount container %s: %w", ctr.ID(), err) } if err := s.ContainerStateToDisk(ctx, ctr); err != nil { diff --git a/server/container_stop_test.go b/server/container_stop_test.go index c2378c0aaba..d209e08b797 100644 --- a/server/container_stop_test.go +++ b/server/container_stop_test.go @@ -4,6 +4,7 @@ import ( "context" "github.com/cri-o/cri-o/internal/oci" + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" specs "github.com/opencontainers/runtime-spec/specs-go" @@ -27,6 +28,10 @@ var _ = t.Describe("ContainerStop", func() { testContainer.SetState(&oci.ContainerState{ State: specs.State{Status: oci.ContainerStateStopped}, }) + gomock.InOrder( + runtimeServerMock.EXPECT().StopContainer(gomock.Any(), gomock.Any()). + Return(nil), + ) // When _, err := sut.StopContainer(context.Background(), diff --git a/server/sandbox_stop_linux.go b/server/sandbox_stop_linux.go index 4c28d54ab4b..6da27ad29e1 100644 --- a/server/sandbox_stop_linux.go +++ b/server/sandbox_stop_linux.go @@ -72,7 +72,7 @@ func (s *Server) stopPodSandbox(ctx context.Context, sb *sandbox.Sandbox) error } } - if err := s.stopContainer(ctx, podInfraContainer, int64(10)); err != nil && !errors.Is(err, storage.ErrContainerUnknown) && !errors.Is(err, oci.ErrContainerStopped) { + if err := s.stopContainer(ctx, podInfraContainer, int64(10)); err != nil && !errors.Is(err, storage.ErrContainerUnknown) { return fmt.Errorf("failed to stop infra container for pod sandbox %s: %v", sb.ID(), err) }