From 50dc42b6be44eda93d3264d5778cc6126fc5c14d Mon Sep 17 00:00:00 2001 From: Jules Ivanic Date: Mon, 10 Feb 2025 17:01:58 +1100 Subject: [PATCH 1/3] Add internal `Promise#succeedUnit` method to avoid `Exit` allocation when possible Idea comes from @hearnadam's review in this PR: https://github.com/zio/zio/pull/9556 --- .../scala/zio/concurrent/CountdownLatch.scala | 2 +- .../scala/zio/concurrent/CyclicBarrier.scala | 2 +- .../scala/zio/concurrent/ReentrantLock.scala | 2 +- core/shared/src/main/scala/zio/Hub.scala | 4 ++-- core/shared/src/main/scala/zio/Promise.scala | 12 ++++++++++ .../shared/src/main/scala/zio/Semaphore.scala | 4 ++-- .../src/main/scala/zio/stream/ZChannel.scala | 22 +++++++++---------- .../src/main/scala/zio/stream/ZStream.scala | 12 +++++----- 8 files changed, 36 insertions(+), 24 deletions(-) diff --git a/concurrent/src/main/scala/zio/concurrent/CountdownLatch.scala b/concurrent/src/main/scala/zio/concurrent/CountdownLatch.scala index e2e116dc3157..14f66abf8770 100644 --- a/concurrent/src/main/scala/zio/concurrent/CountdownLatch.scala +++ b/concurrent/src/main/scala/zio/concurrent/CountdownLatch.scala @@ -38,7 +38,7 @@ final class CountdownLatch private (_count: Ref[Int], _waiters: Promise[Nothing, */ val countDown: UIO[Unit] = _count.modify { case 0 => ZIO.unit -> 0 - case 1 => _waiters.succeed(()) -> 0 + case 1 => _waiters.succeedUnit -> 0 case n => ZIO.unit -> (n - 1) }.flatten.unit diff --git a/concurrent/src/main/scala/zio/concurrent/CyclicBarrier.scala b/concurrent/src/main/scala/zio/concurrent/CyclicBarrier.scala index 3edb340806e1..40c699b34529 100644 --- a/concurrent/src/main/scala/zio/concurrent/CyclicBarrier.scala +++ b/concurrent/src/main/scala/zio/concurrent/CyclicBarrier.scala @@ -29,7 +29,7 @@ final class CyclicBarrier private ( _lock.get.flatMap(_.fail(()).unit) private val succeed: UIO[Unit] = - _lock.get.flatMap(_.succeed(()).unit) + _lock.get.flatMap(_.succeedUnit.unit) /** The number of parties required to trip this barrier. */ def parties: Int = _parties diff --git a/concurrent/src/main/scala/zio/concurrent/ReentrantLock.scala b/concurrent/src/main/scala/zio/concurrent/ReentrantLock.scala index ec929b620879..5beeeb845ba2 100644 --- a/concurrent/src/main/scala/zio/concurrent/ReentrantLock.scala +++ b/concurrent/src/main/scala/zio/concurrent/ReentrantLock.scala @@ -130,7 +130,7 @@ final class ReentrantLock private (fairness: Boolean, state: Ref[ReentrantLock.S ZIO.unit -> State(epoch + 1, None, 0, Map.empty) else { val (fiberId, (_, promise)) = if (fairness) holders.minBy(_._2._1) else pickRandom(holders) - promise.succeed(()).unit -> State(epoch + 1, Some(fiberId), 1, holders - fiberId) + promise.succeedUnit.unit -> State(epoch + 1, Some(fiberId), 1, holders - fiberId) } private def pickRandom( diff --git a/core/shared/src/main/scala/zio/Hub.scala b/core/shared/src/main/scala/zio/Hub.scala index 2950e7d6923d..c5affe3d29b7 100644 --- a/core/shared/src/main/scala/zio/Hub.scala +++ b/core/shared/src/main/scala/zio/Hub.scala @@ -160,7 +160,7 @@ object Hub { ZIO.fiberIdWith { fiberId => shutdownFlag.set(true) ZIO - .whenZIO(shutdownHook.succeed(())) { + .whenZIO(shutdownHook.succeedUnit) { scope.close(Exit.interrupt(fiberId)) *> strategy.shutdown } .unit @@ -229,7 +229,7 @@ object Hub { ZIO.fiberIdWith { fiberId => shutdownFlag.set(true) ZIO - .whenZIO(shutdownHook.succeed(())) { + .whenZIO(shutdownHook.succeedUnit) { ZIO.foreachPar(unsafePollAll(pollers))(_.interruptAs(fiberId)) *> ZIO.succeed { subscribers.remove(subscription -> pollers) diff --git a/core/shared/src/main/scala/zio/Promise.scala b/core/shared/src/main/scala/zio/Promise.scala index 73b21ec656f1..5c09c26bd3d1 100644 --- a/core/shared/src/main/scala/zio/Promise.scala +++ b/core/shared/src/main/scala/zio/Promise.scala @@ -176,6 +176,14 @@ final class Promise[E, A] private ( def succeed(a: A)(implicit trace: Trace): UIO[Boolean] = ZIO.succeed(unsafe.succeed(a)(trace, Unsafe.unsafe)) + /** + * Internally, you can use this method instead of calling `myPromise.succeed(())` + * + * It avoids the `Exit` allocation + */ + private[zio] def succeedUnit(implicit ev0: A =:= Unit, trace: Trace): UIO[Boolean] = + ZIO.succeed(unsafe.succeedUnit(ev0, trace, Unsafe)) + private def interruptJoiner(joiner: IO[E, A] => Any)(implicit trace: Trace): UIO[Any] = ZIO.succeed { var retry = true @@ -205,6 +213,7 @@ final class Promise[E, A] private ( def poll(implicit unsafe: Unsafe): Option[IO[E, A]] def refailCause(e: Cause[E])(implicit trace: Trace, unsafe: Unsafe): Boolean def succeed(a: A)(implicit trace: Trace, unsafe: Unsafe): Boolean + def succeedUnit(implicit ev0: A =:= Unit, trace: Trace, unsafe: Unsafe): Boolean } private[zio] val unsafe: UnsafeAPI = @@ -280,6 +289,9 @@ final class Promise[E, A] private ( def succeed(a: A)(implicit trace: Trace, unsafe: Unsafe): Boolean = completeWith(Exit.succeed(a)) + + override def succeedUnit(implicit ev0: A =:= Unit, trace: Trace, unsafe: Unsafe): Boolean = + completeWith(Exit.unit.asInstanceOf[IO[E, A]]) } } diff --git a/core/shared/src/main/scala/zio/Semaphore.scala b/core/shared/src/main/scala/zio/Semaphore.scala index f17ad2322220..fbad74f692c1 100644 --- a/core/shared/src/main/scala/zio/Semaphore.scala +++ b/core/shared/src/main/scala/zio/Semaphore.scala @@ -155,9 +155,9 @@ object Semaphore { case None => acc -> Right(n) case Some(((promise, permits), queue)) => if (n > permits) - loop(n - permits, Left(queue), acc *> promise.succeed(())) + loop(n - permits, Left(queue), acc *> promise.succeedUnit) else if (n == permits) - (acc *> promise.succeed(())) -> Left(queue) + (acc *> promise.succeedUnit) -> Left(queue) else acc -> Left((promise -> (permits - n)) +: queue) } diff --git a/streams/shared/src/main/scala/zio/stream/ZChannel.scala b/streams/shared/src/main/scala/zio/stream/ZChannel.scala index 46d03df56bc9..65f17137403b 100644 --- a/streams/shared/src/main/scala/zio/stream/ZChannel.scala +++ b/streams/shared/src/main/scala/zio/stream/ZChannel.scala @@ -689,11 +689,11 @@ sealed trait ZChannel[-Env, -InErr, -InElem, -InDone, +OutErr, +OutElem, +OutDon permits .withPermit( - latch.succeed(()) *> + latch.succeedUnit *> f(outElem) .catchAllCause(cause => failureRef.update(_ && cause).unless(cause.isInterruptedOnly) *> - errorSignal.succeed(()) *> + errorSignal.succeedUnit *> ZChannel.failLeftUnit ) ) @@ -770,11 +770,11 @@ sealed trait ZChannel[-Env, -InErr, -InElem, -InDone, +OutErr, +OutElem, +OutDon for { _ <- permits .withPermit( - latch.succeed(()) *> f(outElem) + latch.succeedUnit *> f(outElem) .foldCauseZIO( cause => failure.update(_ && cause).unless(cause.isInterruptedOnly) *> - errorSignal.succeed(()) *> + errorSignal.succeedUnit *> outgoing.offer(ZChannel.failLeftUnit), elem => outgoing.offer(Exit.succeed(elem)) ) @@ -1248,8 +1248,8 @@ sealed trait ZChannel[-Env, -InErr, -InElem, -InDone, +OutErr, +OutElem, +OutDon fiber <- restore(run(channelPromise, scopePromise, child)).forkDaemon _ <- parent.addFinalizer { channelPromise.isDone.flatMap { isDone => - if (isDone) scopePromise.succeed(()) *> fiber.await *> fiber.inheritAll - else scopePromise.succeed(()) *> fiber.interrupt *> fiber.inheritAll + if (isDone) scopePromise.succeedUnit *> fiber.await *> fiber.inheritAll + else scopePromise.succeedUnit *> fiber.interrupt *> fiber.inheritAll } } done <- restore(channelPromise.await) @@ -1986,9 +1986,9 @@ object ZChannel { } } case Left(l: Left[OutErr, OutDone]) => - outgoing.offer(Result.Error(l.value)) *> errorSignal.succeed(()).unit + outgoing.offer(Result.Error(l.value)) *> errorSignal.succeedUnit.unit case Right(cause) => - outgoing.offer(Result.Fatal(cause)) *> errorSignal.succeed(()).unit + outgoing.offer(Result.Fatal(cause)) *> errorSignal.succeedUnit.unit } ) @@ -2005,7 +2005,7 @@ object ZChannel { } permits - .withPermit(latch.succeed(()) *> raceIOs) + .withPermit(latch.succeedUnit *> raceIOs) .interruptible .forkIn(childScope) *> latch.await } @@ -2017,7 +2017,7 @@ object ZChannel { for { size <- cancelers.size - _ <- ZIO.when(size >= n0)(cancelers.take.flatMap(_.succeed(()))) + _ <- ZIO.when(size >= n0)(cancelers.take.flatMap(_.succeedUnit)) _ <- cancelers.offer(canceler) raceIOs = ZIO.scopedWith { scope => @@ -2026,7 +2026,7 @@ object ZChannel { .flatMap(evaluatePull(_).race(canceler.await.interruptible)) } _ <- permits - .withPermit(latch.succeed(()) *> raceIOs) + .withPermit(latch.succeedUnit *> raceIOs) .interruptible .forkIn(childScope) _ <- latch.await diff --git a/streams/shared/src/main/scala/zio/stream/ZStream.scala b/streams/shared/src/main/scala/zio/stream/ZStream.scala index 96ef2cbdabc7..e5c61408e728 100644 --- a/streams/shared/src/main/scala/zio/stream/ZStream.scala +++ b/streams/shared/src/main/scala/zio/stream/ZStream.scala @@ -534,7 +534,7 @@ final class ZStream[-R, +E, +A] private (val channel: ZChannel[R, Any, Any, Any, ): ZChannel[R1, Any, Any, Any, E1, Chunk[A1], Unit] = { lazy val process: ZChannel[Any, Any, Any, Any, E1, Chunk[A1], Unit] = ZChannel.fromZIO(queue.take).flatMap { case (take, promise) => - ZChannel.fromZIO(promise.succeed(())) *> + ZChannel.fromZIO(promise.succeedUnit) *> take.fold( ZChannel.unit, error => ZChannel.refailCause(error), @@ -549,7 +549,7 @@ final class ZStream[-R, +E, +A] private (val channel: ZChannel[R, Any, Any, Any, for { queue <- scoped start <- Promise.make[Nothing, Unit] - _ <- start.succeed(()) + _ <- start.succeedUnit ref <- Ref.make(start) _ <- (channel >>> producer(queue, ref)).runScoped.forkScoped } yield consumer(queue) @@ -3320,7 +3320,7 @@ final class ZStream[-R, +E, +A] private (val channel: ZChannel[R, Any, Any, Any, .pipeTo(loop) .ensuring(queue.offer(Take.end).forkDaemon *> queue.awaitShutdown) *> ZChannel.unit ) - .merge(ZStream.execute((promise.succeed(()) *> right.run(sink)).ensuring(queue.shutdown)), HaltStrategy.Both) + .merge(ZStream.execute((promise.succeedUnit *> right.run(sink)).ensuring(queue.shutdown)), HaltStrategy.Both) } /** @@ -5656,14 +5656,14 @@ object ZStream extends ZStreamPlatformSpecificConstructors { Promise.make[Nothing, Unit].flatMap { p => ref.modify { case s @ Handoff.State.Full(_, notifyProducer) => (notifyProducer.await *> offer(a), s) - case Handoff.State.Empty(notifyConsumer) => (notifyConsumer.succeed(()) *> p.await, Handoff.State.Full(a, p)) + case Handoff.State.Empty(notifyConsumer) => (notifyConsumer.succeedUnit *> p.await, Handoff.State.Full(a, p)) }.flatten } def take(implicit trace: Trace): UIO[A] = Promise.make[Nothing, Unit].flatMap { p => ref.modify { - case Handoff.State.Full(a, notifyProducer) => (notifyProducer.succeed(()).as(a), Handoff.State.Empty(p)) + case Handoff.State.Full(a, notifyProducer) => (notifyProducer.succeedUnit.as(a), Handoff.State.Empty(p)) case s @ Handoff.State.Empty(notifyConsumer) => (notifyConsumer.await *> take, s) }.flatten } @@ -5671,7 +5671,7 @@ object ZStream extends ZStreamPlatformSpecificConstructors { def poll(implicit trace: Trace): UIO[Option[A]] = Promise.make[Nothing, Unit].flatMap { p => ref.modify { - case Handoff.State.Full(a, notifyProducer) => (notifyProducer.succeed(()).as(Some(a)), Handoff.State.Empty(p)) + case Handoff.State.Full(a, notifyProducer) => (notifyProducer.succeedUnit.as(Some(a)), Handoff.State.Empty(p)) case s @ Handoff.State.Empty(_) => (ZIO.succeed(None), s) }.flatten } From 5232c2df39b84a4171ac9c4339aeb0371af7f1f9 Mon Sep 17 00:00:00 2001 From: Jules Ivanic Date: Mon, 10 Feb 2025 17:37:32 +1100 Subject: [PATCH 2/3] fmt --- core/shared/src/main/scala/zio/Promise.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/shared/src/main/scala/zio/Promise.scala b/core/shared/src/main/scala/zio/Promise.scala index 5c09c26bd3d1..e2492a14836b 100644 --- a/core/shared/src/main/scala/zio/Promise.scala +++ b/core/shared/src/main/scala/zio/Promise.scala @@ -177,7 +177,8 @@ final class Promise[E, A] private ( ZIO.succeed(unsafe.succeed(a)(trace, Unsafe.unsafe)) /** - * Internally, you can use this method instead of calling `myPromise.succeed(())` + * Internally, you can use this method instead of calling + * `myPromise.succeed(())` * * It avoids the `Exit` allocation */ From a086d9332fa49fcde76a42213b49925bcc454ded Mon Sep 17 00:00:00 2001 From: Jules Ivanic Date: Mon, 10 Feb 2025 18:36:22 +1100 Subject: [PATCH 3/3] Review: Use `whenZIODiscard` when possible --- core/shared/src/main/scala/zio/Hub.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/core/shared/src/main/scala/zio/Hub.scala b/core/shared/src/main/scala/zio/Hub.scala index c5affe3d29b7..fded13b2e99f 100644 --- a/core/shared/src/main/scala/zio/Hub.scala +++ b/core/shared/src/main/scala/zio/Hub.scala @@ -160,10 +160,9 @@ object Hub { ZIO.fiberIdWith { fiberId => shutdownFlag.set(true) ZIO - .whenZIO(shutdownHook.succeedUnit) { + .whenZIODiscard(shutdownHook.succeedUnit) { scope.close(Exit.interrupt(fiberId)) *> strategy.shutdown } - .unit }.uninterruptible def size(implicit trace: Trace): UIO[Int] = ZIO.suspendSucceed { @@ -229,15 +228,14 @@ object Hub { ZIO.fiberIdWith { fiberId => shutdownFlag.set(true) ZIO - .whenZIO(shutdownHook.succeedUnit) { - ZIO.foreachPar(unsafePollAll(pollers))(_.interruptAs(fiberId)) *> + .whenZIODiscard(shutdownHook.succeedUnit) { + ZIO.foreachParDiscard(unsafePollAll(pollers))(_.interruptAs(fiberId)) *> ZIO.succeed { subscribers.remove(subscription -> pollers) subscription.unsubscribe() strategy.unsafeOnHubEmptySpace(hub, subscribers) } } - .unit }.uninterruptible def size(implicit trace: Trace): UIO[Int] = ZIO.suspendSucceed {