Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ object TReentrantLockSpec extends DefaultRunnableSpec {
reader2 <- lock.readLock.use(count => wlatch.succeed(()) as count).fork
_ <- wlatch.await
count <- reader2.join
} yield assert(count)(equalTo(2))
} yield assert(count)(equalTo(1))
} @@ timeout(10.seconds),
testM("1 write lock then 1 read lock, different fibers") {
for {
Expand All @@ -61,7 +61,7 @@ object TReentrantLockSpec extends DefaultRunnableSpec {
mlatch <- Promise.make[Nothing, Unit]
_ <- lock.writeLock.use(count => rlatch.succeed(()) *> wlatch.await as count).fork
_ <- rlatch.await
reader <- (mlatch.succeed(()) *> lock.readLock.use(ZIO.succeed(_))).fork
reader <- (mlatch.succeed(()) *> lock.readLock.use(ZIO.succeedNow(_))).fork
_ <- mlatch.await
locks <- (lock.readLocks zipWith lock.writeLocks)(_ + _).commit
option <- reader.poll.repeat(pollSchedule)
Expand Down
12 changes: 6 additions & 6 deletions core/shared/src/main/scala/zio/stm/TMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ final class TMap[K, V] private (
tSize.unsafeSet(journal, currSize - 1)
}

TExit.Succeed(())
TExit.unit
})

/**
Expand Down Expand Up @@ -183,7 +183,7 @@ final class TMap[K, V] private (
}
}

TExit.Succeed(())
TExit.unit
})
}

Expand Down Expand Up @@ -218,7 +218,7 @@ final class TMap[K, V] private (

tSize.unsafeSet(journal, newSize)

TExit.Succeed(())
TExit.unit
})

/**
Expand Down Expand Up @@ -252,7 +252,7 @@ final class TMap[K, V] private (

tSize.unsafeSet(journal, newSize)

TExit.Succeed(())
TExit.unit
})

/**
Expand Down Expand Up @@ -336,7 +336,7 @@ final class TMap[K, V] private (

tSize.unsafeSet(journal, newSize)

TExit.Succeed(())
TExit.unit
})

/**
Expand Down Expand Up @@ -372,7 +372,7 @@ final class TMap[K, V] private (
}

tSize.unsafeSet(journal, newSize)
TExit.Succeed(())
TExit.unit
})
}
}
Expand Down
147 changes: 88 additions & 59 deletions core/shared/src/main/scala/zio/stm/TReentrantLock.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ package zio.stm

import TReentrantLock._

import zio.{ Fiber, Managed, UIO, UManaged }
import zio.stm.ZSTM.internal.TExit
import zio.{ Fiber, Managed, UManaged }

/**
* A `TReentrantLock` is a reentrant read/write lock. Multiple readers may all
Expand All @@ -36,33 +37,35 @@ import zio.{ Fiber, Managed, UIO, UManaged }
* allows this structure to be composed into more complicated concurrent
* structures that are consumed from effectful code.
*/
final class TReentrantLock private (data: TRef[Either[ReadLock, WriteLock]]) {
final class TReentrantLock private (data: TRef[LockState]) {

// Preallocated success result carrying `1`, returned by `acquireWrite` on the
// path where this fiber takes its first write lock, so that path reuses one
// shared `TExit` instance instead of building a new one each time.
private val tExitOne = TExit.Succeed(1)

/**
* Acquires a read lock. The transaction will suspend until no other fiber
* is holding a write lock. Succeeds with the number of read locks held by
* this fiber.
* is holding a write lock. Succeeds with the number of read locks held by this fiber.
*/
lazy val acquireRead: USTM[Int] =
STM.fiberId.flatMap(fiberId => adjustRead(fiberId, 1))
lazy val acquireRead: USTM[Int] = adjustRead(1)

/**
* Acquires a write lock. The transaction will suspend until no other
* fibers are holding read or write locks. Succeeds with the number of
* write locks held by this fiber.
*/
lazy val acquireWrite: USTM[Int] =
for {
fiberId <- STM.fiberId
w <- data.get.collect {
case Left(readLock) if (readLock.noOtherHolder(fiberId)) =>
WriteLock(1, readLock.readLocks(fiberId), fiberId)
lazy val acquireWrite: USTM[Int] = new ZSTM((journal, fiberId, _, _) =>
data.unsafeGet(journal) match {

case readLock: ReadLock if readLock.noOtherHolder(fiberId) =>
data.unsafeSet(journal, WriteLock(1, readLock.readLocks(fiberId), fiberId))
tExitOne

case Right(WriteLock(n, m, `fiberId`)) =>
WriteLock(n + 1, m, fiberId)
}
_ <- data.set(Right(w))
} yield w.writeLocks
case WriteLock(n, m, `fiberId`) =>
data.unsafeSet(journal, WriteLock(n + 1, m, fiberId))
TExit.Succeed(n + 1)

case _ => TExit.Retry
}
)

/**
* Just a convenience method for applications that only need reentrant locks,
Expand All @@ -85,40 +88,47 @@ final class TReentrantLock private (data: TRef[Either[ReadLock, WriteLock]]) {
Managed.make(acquireRead.commit)(_ => releaseRead.commit)

/**
* Retrieves the number of acquired read locks.
* Retrieves the total number of acquired read locks.
*/
def readLocks: USTM[Int] = data.get.map(_.fold(_.readLocks, _.readLocks))
def readLocks: USTM[Int] = data.get.map(_.readLocks)

/**
 * Counts the read locks currently held by the calling fiber. A fiber that
 * holds no read locks observes `0`.
 */
def fiberReadLocks: USTM[Int] =
  new ZSTM((journal, fiberId, _, _) => {
    val state = data.unsafeGet(journal)
    TExit.Succeed(state.readLocks(fiberId))
  })

/**
 * Counts the write locks currently held by the calling fiber. A fiber that
 * does not own the write lock observes `0`.
 */
def fiberWriteLocks: USTM[Int] =
  new ZSTM((journal, fiberId, _, _) => {
    val state = data.unsafeGet(journal)
    TExit.Succeed(state.writeLocks(fiberId))
  })

/**
* Determines if any fiber has a read lock.
*/
def readLocked: USTM[Boolean] = readLocks.map(_ > 0)
def readLocked: USTM[Boolean] = data.get.map(_.readLocks > 0)

/**
* Releases a read lock held by this fiber. Succeeds with the outstanding
* number of read locks held by this fiber.
*/
lazy val releaseRead: USTM[Int] =
STM.fiberId.flatMap(fiberId => adjustRead(fiberId, -1))
lazy val releaseRead: USTM[Int] = adjustRead(-1)

/**
* Releases a write lock held by this fiber. Succeeds with the outstanding
* number of write locks held by this fiber.
*/
lazy val releaseWrite: USTM[Int] =
STM.fiberId.flatMap(fiberId =>
data.modify {
case Right(WriteLock(1, m, `fiberId`)) =>
0 -> Left(ReadLock(fiberId, m))

case Right(WriteLock(n, m, `fiberId`)) if n > 1 =>
val newCount = n - 1

newCount -> Right(WriteLock(newCount, m, fiberId))

case s => die(s"Defect: Fiber ${fiberId} releasing write lock it does not hold: ${s}")
}
)
lazy val releaseWrite: USTM[Int] = new ZSTM((journal, fiberId, _, _) => {
  // Work out the lock state that remains after giving back one write lock.
  // Dropping the last write lock falls back to a read-lock state that keeps
  // this fiber's read-lock count; any other shape is a defect.
  val remaining: LockState = data.unsafeGet(journal) match {
    case WriteLock(n, m, `fiberId`) if n > 1 => WriteLock(n - 1, m, fiberId)
    case WriteLock(1, m, `fiberId`)          => ReadLock(fiberId, m)
    case s                                   => die(s"Defect: Fiber ${fiberId} releasing write lock it does not hold: ${s}")
  }

  data.unsafeSet(journal, remaining)
  TExit.Succeed(remaining.writeLocks(fiberId))
})

/**
* Obtains a write lock in a managed context.
Expand All @@ -129,44 +139,66 @@ final class TReentrantLock private (data: TRef[Either[ReadLock, WriteLock]]) {
/**
* Determines if a write lock is held by some fiber.
*/
def writeLocked: USTM[Boolean] = writeLocks.map(_ > 0)
def writeLocked: USTM[Boolean] = data.get.map(_.writeLocks > 0)

/**
* Computes the number of write locks held by fibers.
*/
def writeLocks: USTM[Int] = data.get.map(_.fold(_ => 0, _.writeLocks))

private def adjustRead(fiberId: Fiber.Id, delta: Int): USTM[Int] =
(data.get.collect {
case Left(readLock) => Left(readLock.adjust(fiberId, delta))
case Right(wl @ WriteLock(w, r, `fiberId`)) =>
val newTotal = r + delta

if (newTotal < 0) die(s"Defect: Fiber ${fiberId} releasing read locks it does not hold: ${wl}")
else Right(WriteLock(w, newTotal, fiberId))
}.flatMap(data.set(_)) *> data.get.map(_.fold(_.readLocks, _.readLocks)))
def writeLocks: USTM[Int] = data.get.map(_.writeLocks)

/**
 * Adds `delta` (positive to acquire, negative to release) to the calling
 * fiber's read-lock count and succeeds with that fiber's new count.
 *
 * - In the read-lock state the adjustment is delegated to `ReadLock.adjust`.
 * - If the calling fiber owns the write lock, its embedded read count is
 *   adjusted. Dropping below zero is a defect, reported via `die` together
 *   with the offending lock state.
 * - If another fiber owns the write lock, the transaction retries.
 */
private def adjustRead(delta: Int): USTM[Int] =
  new ZSTM((journal, fiberId, _, _) =>
    data.unsafeGet(journal) match {

      case readLock: ReadLock =>
        val res = readLock.adjust(fiberId, delta)
        data.unsafeSet(journal, res)
        TExit.Succeed(res.readLocks(fiberId))

      case wl @ WriteLock(w, r, `fiberId`) =>
        val newTotal = r + delta
        if (newTotal < 0)
          // Include the lock state in the message, matching releaseWrite's defect report.
          die(s"Defect: Fiber ${fiberId} releasing read locks it does not hold: ${wl}, newTotal: $newTotal")
        else {
          data.unsafeSet(journal, WriteLock(w, newTotal, fiberId))
          TExit.Succeed(newTotal)
        }

      case _ =>
        // Another fiber is holding the write lock: wait until it is released.
        TExit.Retry
    }
  )
}
object TReentrantLock {

/**
 * The state held in a `TReentrantLock`. It is either a `ReadLock` (zero or
 * more fibers holding read locks) or a `WriteLock` (a single fiber holding
 * the write lock, possibly together with read locks).
 */
private[stm] sealed trait LockState {

  /** Total number of read locks held across all fibers. */
  def readLocks: Int

  /** Number of read locks held by the given fiber. */
  def readLocks(fiberId: Fiber.Id): Int

  // Declared as `def` (not an abstract `val`) for consistency with `readLocks`.
  // The `WriteLock` case-class parameter and `ReadLock`'s `override val` both
  // still implement it.
  /** Total number of write locks held. */
  def writeLocks: Int

  /** Number of write locks held by the given fiber. */
  def writeLocks(fiberId: Fiber.Id): Int
}

/**
* This data structure describes the state of the lock when a single fiber
* has a write lock. The fiber has an identity, and may also have acquired
* a certain number of read locks.
*/
private[stm] final case class WriteLock(writeLocks: Int, readLocks: Int, fiberId: Fiber.Id)
private[stm] final case class WriteLock(writeLocks: Int, readLocks: Int, fiberId: Fiber.Id) extends LockState {

  // Only the owning fiber holds locks in this state; every other fiber sees 0.
  override def readLocks(fiberId0: Fiber.Id): Int =
    if (isOwner(fiberId0)) readLocks else 0

  override def writeLocks(fiberId0: Fiber.Id): Int =
    if (isOwner(fiberId0)) writeLocks else 0

  /** True when `candidate` is the fiber that owns this write lock. */
  private def isOwner(candidate: Fiber.Id): Boolean = candidate == fiberId
}

/**
* This data structure describes the state of the lock when multiple fibers
* have acquired read locks. The state is tracked as a map from fiber identity
* to number of read locks acquired by the fiber. This level of detail permits
* upgrading a read lock to a write lock.
*/
private[stm] final class ReadLock private (readers: Map[Fiber.Id, Int]) {
private[stm] final class ReadLock(readers: Map[Fiber.Id, Int]) extends LockState {

/**
* Computes the total number of read locks acquired.
*/
def readLocks: Int = readers.values.sum
lazy val readLocks: Int = readers.values.sum

/**
* Determines if there is no other holder of read locks aside from the
Expand All @@ -180,7 +212,7 @@ object TReentrantLock {
/**
* Computes the number of read locks held by the specified fiber id.
*/
def readLocks(fiberId: Fiber.Id): Int = readers.get(fiberId).getOrElse(0)
def readLocks(fiberId: Fiber.Id): Int = readers.getOrElse(fiberId, 0)

/**
* Adjusts the number of read locks held by the specified fiber id.
Expand All @@ -197,6 +229,9 @@ object TReentrantLock {
)
}

override val writeLocks: Int = 0

override def writeLocks(fiberId: Fiber.Id): Int = 0
}
private[stm] object ReadLock {

Expand All @@ -217,13 +252,7 @@ object TReentrantLock {
* Makes a new reentrant read/write lock.
*/
def make: USTM[TReentrantLock] =
TRef.make[Either[ReadLock, WriteLock]](Left(ReadLock.empty)).map(tref => new TReentrantLock(tref))

/**
* Makes a new reentrant read/write lock.
*/
val makeCommit: UIO[TReentrantLock] =
make.commit
TRef.make[LockState](ReadLock.empty).map(new TReentrantLock(_))

private def die(message: String): Nothing = {
  // Lock-ownership defects are signalled by throwing a plain RuntimeException
  // carrying the supplied message.
  val defect = new RuntimeException(message)
  throw defect
}
Expand Down
2 changes: 2 additions & 0 deletions core/shared/src/main/scala/zio/stm/ZSTM.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1781,6 +1781,8 @@ object ZSTM {

sealed trait TExit[+A, +B] extends Serializable with Product
object TExit {
/** Shared successful exit carrying `()`, reused by callers instead of allocating `Succeed(())` each time. */
val unit: TExit[Nothing, Unit] = TExit.Succeed(())

final case class Fail[+A](value: A) extends TExit[A, Nothing]
final case class Succeed[+B](value: B) extends TExit[Nothing, B]
case object Retry extends TExit[Nothing, Nothing]
Expand Down
4 changes: 2 additions & 2 deletions core/shared/src/main/scala/zio/stm/ZTRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ object ZTRef {
new ZSTM((journal, _, _, _) => {
val entry = getOrMakeEntry(journal)
entry.unsafeSet(a)
TExit.Succeed(())
TExit.unit
})

/**
Expand Down Expand Up @@ -277,7 +277,7 @@ object ZTRef {
val entry = getOrMakeEntry(journal)
val newValue = f(entry.unsafeGet[A])
entry.unsafeSet(newValue)
TExit.Succeed(())
TExit.unit
})

/**
Expand Down