Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion core-tests/shared/src/test/scala/zio/SupervisorSpec.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package zio

import zio.Clock.ClockLive
import zio.test.TestAspect.{exceptJS, nonFlaky}
import zio.test._

import java.util.concurrent.atomic.AtomicInteger

object SupervisorSpec extends ZIOBaseSpec {

def spec = suite("SupervisorSpec")(
Expand Down Expand Up @@ -31,7 +34,27 @@ object SupervisorSpec extends ZIOBaseSpec {
DifferSpec.diffLaws(Differ.supervisor)(genSupervisor)((left, right) =>
Supervisor.toSet(left) == Supervisor.toSet(right)
)
}
},
suite("onStart and onEnd are called exactly once")(
  test("sync effect") {
    for {
      tracker <- ZIO.succeed(new StartEndTrackingSupervisor)
      child   <- ZIO.unit.fork.supervised(tracker)
      _       <- child.await
      // The child can signal completion to this fiber before the runtime has
      // run the `onEnd` hook, so poll until the hook has fired at least once.
      _ <- ZIO.succeed(tracker.onEndCalls).repeatUntil(_ > 0)
    } yield assertTrue(tracker.onStartCalls == 1, tracker.onEndCalls == 1)
  },
  test("async effect") {
    for {
      tracker <- ZIO.succeed(new StartEndTrackingSupervisor)
      child   <- ClockLive.sleep(100.micros).fork.supervised(tracker)
      _       <- child.await
      // Same race as the sync case: completion may be observed before `onEnd`
      // has run, so wait for the first `onEnd` before asserting the counts.
      _ <- ZIO.succeed(tracker.onEndCalls).repeatUntil(_ > 0)
    } yield assertTrue(tracker.onStartCalls == 1, tracker.onEndCalls == 1)
  }
) @@ nonFlaky(100)
)

val genSupervisor: Gen[Any, Supervisor[Any]] =
Expand Down Expand Up @@ -105,4 +128,29 @@ object SupervisorSpec extends ZIOBaseSpec {
promise.unsafe.done(ZIO.succeed(ref.unsafe.get))
}
}

/**
 * Test-only [[Supervisor]] that counts how many times its `onStart` and
 * `onEnd` hooks are invoked, so a spec can assert that each hook fires
 * exactly once for a supervised fiber.
 *
 * The counters are atomic because the hooks run on the supervised fiber's
 * side while the spec reads the counts from its own (parent) fiber.
 */
private final class StartEndTrackingSupervisor extends Supervisor[Unit] {
  // Two independent counters, declared on separate lines so it is obvious
  // they do not share a single AtomicInteger instance.
  private val _onStartCalls = new AtomicInteger(0)
  private val _onEndCalls   = new AtomicInteger(0)

  /** This supervisor carries no value; it only records hook invocations. */
  def value(implicit trace: Trace): UIO[Unit] = ZIO.unit

  /** Records one `onStart` invocation; all arguments are ignored. */
  def onStart[R, E, A](
    environment: ZEnvironment[R],
    effect: ZIO[R, E, A],
    parent: Option[Fiber.Runtime[Any, Any]],
    fiber: Fiber.Runtime[E, A]
  )(implicit unsafe: Unsafe): Unit = {
    _onStartCalls.incrementAndGet()
    () // explicitly discard the updated count
  }

  /** Records one `onEnd` invocation; the exit value and fiber are ignored. */
  def onEnd[R, E, A](value: Exit[E, A], fiber: Fiber.Runtime[E, A])(implicit unsafe: Unsafe): Unit = {
    _onEndCalls.incrementAndGet()
    () // explicitly discard the updated count
  }

  /** Number of `onStart` invocations recorded so far. */
  def onStartCalls: Int = _onStartCalls.get()

  /** Number of `onEnd` invocations recorded so far. */
  def onEndCalls: Int = _onEndCalls.get()
}

}
4 changes: 0 additions & 4 deletions core/shared/src/main/scala/zio/Runtime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ trait Runtime[+R] { self =>

if (supervisor ne Supervisor.none) {
supervisor.onStart(environment, zio, None, fiber)

fiber.addObserver(exit => supervisor.onEnd(exit, fiber))
}

val exit = fiber.start[R](zio)
Expand Down Expand Up @@ -201,8 +199,6 @@ trait Runtime[+R] { self =>

if (supervisor ne Supervisor.none) {
supervisor.onStart(environment, zio, None, fiber)

fiber.addObserver(exit => supervisor.onEnd(exit, fiber))
}

fiber
Expand Down
2 changes: 0 additions & 2 deletions core/shared/src/main/scala/zio/ZIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2698,8 +2698,6 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific
Some(parentFiber),
childFiber
)

childFiber.addObserver(exit => supervisor.onEnd(exit, childFiber))
}

val parentScope =
Expand Down
7 changes: 3 additions & 4 deletions core/shared/src/main/scala/zio/internal/FiberRuntime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -411,14 +411,11 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs,
val exit =
runLoop(effect, 0, _stackSize, initialDepth, 0).asInstanceOf[Exit[E, A]]

if (null eq exit) {
if (exit eq null) {
// Terminate this evaluation, async resumption will continue evaluation:
_forksSinceYield = 0
effect = null
} else {

if (supervisor ne Supervisor.none) supervisor.onEnd(exit, self)(Unsafe)

self._runtimeFlags = RuntimeFlags.enable(_runtimeFlags)(RuntimeFlag.WindDown)

val interruption = interruptAllChildren()
Expand All @@ -427,6 +424,8 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs,
if (inbox.isEmpty) {
finalExit = exit

if (supervisor ne Supervisor.none) supervisor.onEnd(finalExit, self)(Unsafe)

// No more messages to process, so we will allow the fiber to end life:
self.setExitValue(exit)
} else {
Expand Down
Loading