From 66cfa25876ad01a577d7d39aa32460c87414442d Mon Sep 17 00:00:00 2001
From: IgorDorokhov
Date: Sat, 26 Apr 2025 13:34:18 -0400
Subject: [PATCH] add tryWithPermits to Semaphore and tryAcquire to TSemaphore

---
 .../src/test/scala/zio/SemaphoreSpec.scala    | 46 +++++++++++
 .../test/scala/zio/stm/TSemaphoreSpec.scala   | 77 +++++++++++++++++++
 .../shared/src/main/scala/zio/Semaphore.scala | 39 +++++++++-
 .../src/main/scala/zio/stm/TSemaphore.scala   | 37 +++++++++
 4 files changed, 198 insertions(+), 1 deletion(-)

diff --git a/core-tests/shared/src/test/scala/zio/SemaphoreSpec.scala b/core-tests/shared/src/test/scala/zio/SemaphoreSpec.scala
index 8fd5f58c48b0..78076c00ef76 100644
--- a/core-tests/shared/src/test/scala/zio/SemaphoreSpec.scala
+++ b/core-tests/shared/src/test/scala/zio/SemaphoreSpec.scala
@@ -32,6 +32,52 @@ object SemaphoreSpec extends ZIOBaseSpec {
         permits   <- semaphore.available
       } yield assertTrue(permits == 2L)
     },
+    test("tryWithPermits acquires and releases same number of permits") {
+      for {
+        sem     <- Semaphore.make(3L)
+        ans     <- sem.tryWithPermits(2L)(ZIO.unit)
+        permits <- sem.available
+      } yield assertTrue(permits == 3L && ans.isDefined)
+    },
+    test("tryWithPermits runs the effect if 0 permits are requested") {
+      for {
+        sem     <- Semaphore.make(3L)
+        ans     <- sem.tryWithPermits(0L)(ZIO.succeed("I got executed"))
+        permits <- sem.available
+      } yield assertTrue(permits == 3L && ans.contains("I got executed"))
+    },
+    test("tryWithPermits returns None if not enough permits are available") {
+      for {
+        sem     <- Semaphore.make(3L)
+        ans     <- sem.tryWithPermits(4L)(ZIO.succeed("Shouldn't get executed"))
+        permits <- sem.available
+      } yield assertTrue(permits == 3L && ans.isEmpty)
+    },
+    test("tryWithPermit acquires and releases same number of permits") {
+      for {
+        sem     <- Semaphore.make(3L)
+        ans     <- sem.tryWithPermit(ZIO.unit)
+        permits <- sem.available
+      } yield assertTrue(permits == 3L && ans.isDefined)
+    },
+    test("tryWithPermits dies if a negative number of permits is requested") {
+      for {
+        sem <- Semaphore.make(3L)
+        ans <- sem.tryWithPermits(-1L)(ZIO.unit).exit
+      } yield assert(ans)(dies(isSubtype[IllegalArgumentException](anything)))
+    },
+    test("tryWithPermits restores permits after failure") {
+      for {
+        sem     <- Semaphore.make(3L)
+        failure = ZIO.fail("exception")
+        result  <- sem.tryWithPermits(2L)(failure).exit
+        permits <- sem.available
+      } yield assertTrue(
+        permits == 3L,
+        result.isFailure,
+        result == Exit.fail("exception")
+      )
+    },
     test("awaiting returns the count of waiting fibers") {
       for {
         semaphore <- Semaphore.make(1)
diff --git a/core-tests/shared/src/test/scala/zio/stm/TSemaphoreSpec.scala b/core-tests/shared/src/test/scala/zio/stm/TSemaphoreSpec.scala
index c3c1d4cf01cb..1452a2c7bfb1 100644
--- a/core-tests/shared/src/test/scala/zio/stm/TSemaphoreSpec.scala
+++ b/core-tests/shared/src/test/scala/zio/stm/TSemaphoreSpec.scala
@@ -137,6 +137,83 @@ object TSemaphoreSpec extends ZIOBaseSpec {
           assertTrue(remaining == 3L)
         }
       }
+    ),
+    suite("tryAcquire, tryAcquireN, tryWithPermit and tryWithPermits")(
+      test("tryAcquire should succeed when a permit is available") {
+        for {
+          sem <- TSemaphore.makeCommit(1L)
+          res <- sem.tryAcquire.commit
+        } yield assert(res)(isTrue)
+      },
+      test("tryAcquire should return false when no permits are available") {
+        for {
+          sem <- TSemaphore.makeCommit(0L)
+          res <- sem.tryAcquire.commit
+        } yield assert(res)(isFalse)
+      },
+      test("tryAcquire should decrease the permit count when successful") {
+        for {
+          sem   <- TSemaphore.makeCommit(1L)
+          _     <- sem.tryAcquire.commit
+          avail <- sem.available.commit
+        } yield assert(avail)(equalTo(0L))
+      },
+      test("tryAcquireN should acquire permits if enough are available") {
+        for {
+          sem <- TSemaphore.makeCommit(5L)
+          res <- sem.tryAcquireN(3L).commit
+        } yield assert(res)(isTrue)
+      },
+      test("tryAcquireN should return false if not enough permits are available") {
+        for {
+          sem <- TSemaphore.makeCommit(2L)
+          res <- sem.tryAcquireN(3L).commit
+        } yield assert(res)(isFalse)
+      },
+      test("tryAcquireN should decrease the permit count when successful") {
+        for {
+          sem   <- TSemaphore.makeCommit(5L)
+          _     <- sem.tryAcquireN(3L).commit
+          avail <- sem.available.commit
+        } yield assert(avail)(equalTo(2L))
+      },
+      test("tryAcquireN should not change permit count when unsuccessful") {
+        for {
+          sem   <- TSemaphore.makeCommit(2L)
+          _     <- sem.tryAcquireN(3L).commit
+          avail <- sem.available.commit
+        } yield assert(avail)(equalTo(2L))
+      },
+      test("tryWithPermits should acquire a permit and release it") {
+        for {
+          sem    <- TSemaphore.makeCommit(2L)
+          result <- sem.tryWithPermits(1L)(ZIO.succeed(2))
+          avail  <- sem.available.commit
+        } yield assertTrue(result.contains(2) && avail == 2L)
+      },
+      test("tryWithPermits should return None if no permits available") {
+        for {
+          sem    <- TSemaphore.makeCommit(0L)
+          result <- sem.tryWithPermits(1L)(ZIO.succeed(2))
+          avail  <- sem.available.commit
+        } yield assertTrue(result.isEmpty && avail == 0L)
+      },
+      test(
+        "tryWithPermits should return None if requested amount of permits is greater than available amount of permits"
+      ) {
+        for {
+          sem    <- TSemaphore.makeCommit(3L)
+          result <- sem.tryWithPermits(5L)(ZIO.succeed(2))
+          avail  <- sem.available.commit
+        } yield assertTrue(result.isEmpty && avail == 3L)
+      },
+      test("tryWithPermit should acquire a permit and release it") {
+        for {
+          sem    <- TSemaphore.makeCommit(3L)
+          result <- sem.tryWithPermit(ZIO.succeed(2))
+          avail  <- sem.available.commit
+        } yield assertTrue(result.contains(2) && avail == 3L)
+      }
     )
   )
 
diff --git a/core/shared/src/main/scala/zio/Semaphore.scala b/core/shared/src/main/scala/zio/Semaphore.scala
index fbad74f692c1..1dacb46b0eb8 100644
--- a/core/shared/src/main/scala/zio/Semaphore.scala
+++ b/core/shared/src/main/scala/zio/Semaphore.scala
@@ -45,6 +45,20 @@ sealed trait Semaphore extends Serializable {
    */
   def awaiting(implicit trace: Trace): UIO[Long] = ZIO.succeed(0L)
 
+  /**
+   * Executes the effect, acquiring a permit if available and releasing it after
+   * execution. Returns `None` without running the effect if no permit was available.
+   */
+  final def tryWithPermit[R, E, A](zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, Option[A]] =
+    tryWithPermits(1L)(zio)
+
+  /**
+   * Executes the effect, acquiring `n` permits if available and releasing them
+   * after execution. Returns `None` without running the effect if fewer than `n` permits are available.
+   */
+  def tryWithPermits[R, E, A](n: Long)(zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, Option[A]] =
+    ZIO.none
+
   /**
    * Executes the specified workflow, acquiring a permit immediately before the
    * workflow begins execution and releasing it immediately after the workflow
@@ -71,6 +85,7 @@ sealed trait Semaphore extends Serializable {
    * permits and releasing them when the scope is closed.
    */
   def withPermitsScoped(n: Long)(implicit trace: Trace): ZIO[Scope, Nothing, Unit]
+
 }
 
 object Semaphore {
@@ -110,12 +125,34 @@ object Semaphore {
         def withPermitsScoped(n: Long)(implicit trace: Trace): ZIO[Scope, Nothing, Unit] =
           ZIO.acquireRelease(reserve(n))(_.release).flatMap(_.acquire)
 
+        override def tryWithPermits[R, E, A](n: Long)(zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, Option[A]] =
+          ZIO.acquireReleaseWith(tryReserve(n)) {
+            case Some(reservation) => reservation.release
+            case None              => Exit.unit
+          } {
+            case Some(_) => zio.asSome
+            case None    => Exit.none
+          }
+
         case class Reservation(acquire: UIO[Unit], release: UIO[Any])
+        object Reservation {
+          private[zio] val zero = Reservation(ZIO.unit, ZIO.unit)
+        }
+
+        def tryReserve(n: Long)(implicit trace: Trace): UIO[Option[Reservation]] =
+          if (n < 0) ZIO.die(new IllegalArgumentException(s"Unexpected negative `$n` permits requested."))
+          else if (n == 0L) ZIO.succeed(Some(Reservation.zero))
+          else
+            ref.modify {
+              case Right(permits) if permits >= n =>
+                Some(Reservation(ZIO.unit, releaseN(n))) -> Right(permits - n)
+              case other => None -> other
+            }
 
         def reserve(n: Long)(implicit trace: Trace): UIO[Reservation] =
           if (n < 0) ZIO.die(new IllegalArgumentException(s"Unexpected negative `$n` permits requested."))
           else if (n == 0L)
-            ZIO.succeedNow(Reservation(ZIO.unit, ZIO.unit))
+            ZIO.succeed(Reservation.zero)
           else
             Promise.make[Nothing, Unit].flatMap { promise =>
               ref.modify {
diff --git a/core/shared/src/main/scala/zio/stm/TSemaphore.scala b/core/shared/src/main/scala/zio/stm/TSemaphore.scala
index 5731cf404e7b..3ca0e810ad84 100644
--- a/core/shared/src/main/scala/zio/stm/TSemaphore.scala
+++ b/core/shared/src/main/scala/zio/stm/TSemaphore.scala
@@ -108,6 +108,43 @@ final class TSemaphore private (val permits: TRef[Long]) extends Serializable {
       permits.unsafeSet(journal, current + n)
     }
 
+  /**
+   * Tries to acquire a single permit in a transactional context. Returns `true`
+   * if the permit was acquired, otherwise `false`.
+   */
+  def tryAcquire: USTM[Boolean] = tryAcquireN(1L)
+
+  /**
+   * Tries to acquire the specified number of permits in a transactional
+   * context. Returns `true` if the permits were acquired, otherwise `false`.
+   */
+  def tryAcquireN(n: Long): USTM[Boolean] =
+    ZSTM.Effect { (journal, _, _) =>
+      assertNonNegative(n)
+
+      val available: Long = permits.unsafeGet(journal)
+      if (available >= n) {
+        permits.unsafeSet(journal, available - n)
+        true
+      } else false
+    }
+
+  /**
+   * Executes the specified effect, acquiring a single permit if one is
+   * available and releasing it after execution. Returns `None`, without
+   * running the effect, if no permit was available.
+   */
+  def tryWithPermit[R, E, A](zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, Option[A]] =
+    tryWithPermits(1L)(zio)
+
+  /**
+   * Executes the specified effect, acquiring `n` permits if available and
+   * releasing them after execution. Returns `None`, without running the
+   * effect, if fewer than `n` permits were available.
+   */
+  def tryWithPermits[R, E, A](n: Long)(zio: ZIO[R, E, A])(implicit trace: Trace): ZIO[R, E, Option[A]] =
+    ZSTM.acquireReleaseWith(tryAcquireN(n))(releaseN(n).commit.whenDiscard(_))(zio.when(_))
+
   /**
    * Executes the specified effect, acquiring a permit immediately before the
    * effect begins execution and releasing it immediately after the effect