Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 76 additions & 4 deletions core-tests/shared/src/test/scala/zio/stm/TPriorityQueueSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,85 @@ import zio.test._
object TPriorityQueueSpec extends ZIOBaseSpec {

def spec = suite("TPriorityQueueSpec")(
  testM("offerAll and takeAll") {
    checkM(Gen.chunkOf(Gen.anyInt)) { as =>
      val transaction = for {
        queue  <- TPriorityQueue.empty[Int]
        _      <- queue.offerAll(as)
        values <- queue.takeAll
      } yield values
      assertM(transaction.commit)(equalTo(as.sorted))
    }
  },
  testM("peekOption and takeOption") {
    // Covers both the empty queue (both return None) and the non-empty queue
    // (both return the first element in the ordering, and only take removes it).
    checkM(Gen.listOf(Gen.anyInt)) { as =>
      val transaction = for {
        queue  <- TPriorityQueue.fromIterable(as)
        peeked <- queue.peekOption
        taken  <- queue.takeOption
        size   <- queue.size
      } yield (peeked, taken, size)
      val smallest = as.sorted.headOption
      assertM(transaction.commit)(equalTo((smallest, smallest, math.max(as.length - 1, 0))))
    }
  },
  testM("removeIf") {
    checkM(Gen.listOf(Gen.anyInt), Gen.function(Gen.boolean)) { (as, f) =>
      val transaction = for {
        queue <- TPriorityQueue.fromIterable(as)
        _     <- queue.removeIf(f)
        list  <- queue.toList
      } yield list
      assertM(transaction.commit)(equalTo(as.filterNot(f).sorted))
    }
  },
  testM("retainIf") {
    checkM(Gen.listOf(Gen.anyInt), Gen.function(Gen.boolean)) { (as, f) =>
      val transaction = for {
        queue <- TPriorityQueue.fromIterable(as)
        _     <- queue.retainIf(f)
        list  <- queue.toList
      } yield list
      assertM(transaction.commit)(equalTo(as.filter(f).sorted))
    }
  },
  testM("size") {
    // A narrow range forces duplicates, so the per-value multiplicities must be summed.
    checkM(Gen.listOf(Gen.int(1, 10))) { as =>
      val transaction = for {
        queue <- TPriorityQueue.fromIterable(as)
        size  <- queue.size
      } yield size
      assertM(transaction.commit)(equalTo(as.length))
    }
  },
  testM("take") {
    checkM(Gen.listOf(Gen.anyInt)) { as =>
      val transaction = for {
        queue <- TPriorityQueue.fromIterable(as)
        takes <- STM.collectAll(STM.replicate(as.length)(queue.take))
      } yield takes
      assertM(transaction.commit)(equalTo(as.sorted))
    }
  },
  testM("takeUpTo") {
    // Small value range so that `n` regularly splits a run of equal values.
    val gen = for {
      as <- Gen.chunkOf(Gen.int(1, 10))
      n  <- Gen.int(0, as.length)
    } yield (as, n)
    checkM(gen) {
      case (as, n) =>
        val transaction = for {
          queue <- TPriorityQueue.fromIterable(as)
          left  <- queue.takeUpTo(n)
          right <- queue.takeAll
        } yield (left, right)
        assertM(transaction.commit)(equalTo((as.sorted.take(n), as.sorted.drop(n))))
    }
  },
  testM("toChunk") {
    checkM(Gen.chunkOf(Gen.anyInt)) { as =>
      val transaction = for {
        queue <- TPriorityQueue.fromIterable(as)
        chunk <- queue.toChunk
      } yield chunk
      assertM(transaction.commit)(equalTo(as.sorted))
    }
  },
  testM("toList") {
    checkM(Gen.listOf(Gen.anyInt)) { as =>
      val transaction = for {
        queue <- TPriorityQueue.fromIterable(as)
        list  <- queue.toList
      } yield list
      assertM(transaction.commit)(equalTo(as.sorted))
    }
  },
  testM("toVector") {
    checkM(Gen.vectorOf(Gen.anyInt)) { as =>
      val transaction = for {
        queue  <- TPriorityQueue.fromIterable(as)
        vector <- queue.toVector
      } yield vector
      assertM(transaction.commit)(equalTo(as.sorted))
    }
  }
)
Expand Down
179 changes: 106 additions & 73 deletions core/shared/src/main/scala/zio/stm/TPriorityQueue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,137 +16,168 @@

package zio.stm

import zio.Chunk
import scala.collection.immutable.SortedMap

import zio.stm.ZSTM.internal._
import zio.{ Chunk, ChunkBuilder }

/**
 * A simple `TPriorityQueue` implementation. A `TPriorityQueue` contains values
 * of type `A` that an `Ordering` is defined on. Unlike a `TQueue`, `take`
 * returns the highest priority value (the value that is first in the specified
 * ordering) as opposed to the first value offered to the queue. The ordering
 * that elements with the same priority will be taken from the queue is not
 * guaranteed.
 *
 * Internally the queue is a single transactional reference to a `SortedMap`.
 * The map goes from each value to the number of copies of it currently in the
 * queue, and it is kept in the order of the map's `Ordering`.
 *
 * NOTE(review): because values are stored as map keys, two values that compare
 * as equal under the `Ordering` but are otherwise distinct share one key. Only
 * one representative is kept, and it is returned once per offer. If distinct
 * values with equal priority must be preserved, store the offered values
 * themselves per key. That change also requires updating the companion
 * constructors.
 */
final class TPriorityQueue[A] private (private val ref: TRef[SortedMap[A, Int]]) extends AnyVal {

  /**
   * Offers the specified value to the queue.
   */
  def offer(a: A): USTM[Unit] =
    ref.update(map => map + (a -> map.get(a).fold(1)(_ + 1)))

  /**
   * Offers all of the elements in the specified collection to the queue in a
   * single update of the underlying reference.
   */
  def offerAll(values: Iterable[A]): USTM[Unit] =
    ref.update(map => values.foldLeft(map)((acc, a) => acc + (a -> acc.get(a).fold(1)(_ + 1))))

  /**
   * Peeks at the first value in the queue without removing it, retrying until
   * a value is in the queue.
   */
  def peek: USTM[A] =
    new ZSTM((journal, _, _, _) =>
      ref.unsafeGet(journal).headOption match {
        case None         => TExit.Retry
        case Some((a, _)) => TExit.Succeed(a)
      }
    )

  /**
   * Peeks at the first value in the queue without removing it, returning
   * `None` if there is not a value in the queue.
   */
  def peekOption: USTM[Option[A]] =
    // `get` only reads the ref; `modify` would record a write to the journal.
    ref.get.map(_.headOption.map(_._1))

  /**
   * Removes all elements from the queue matching the specified predicate.
   */
  def removeIf(f: A => Boolean): USTM[Unit] =
    retainIf(!f(_))

  /**
   * Retains only elements from the queue matching the specified predicate.
   */
  def retainIf(f: A => Boolean): USTM[Unit] =
    ref.update(_.filter { case (a, _) => f(a) })

  /**
   * Returns the size of the queue, counting every copy of a repeated value.
   */
  def size: USTM[Int] =
    ref.get.map(_.values.sum)

  /**
   * Takes a value from the queue, retrying until a value is in the queue.
   */
  def take: USTM[A] =
    new ZSTM((journal, _, _, _) => {
      val map = ref.unsafeGet(journal)
      map.headOption match {
        case None => TExit.Retry
        case Some((a, n)) =>
          // Decrement the count, dropping the key once its last copy is taken.
          ref.unsafeSet(journal, if (n == 1) map - a else map + (a -> (n - 1)))
          TExit.Succeed(a)
      }
    })

  /**
   * Takes all values from the queue, in ordering order, leaving it empty.
   */
  def takeAll: USTM[Chunk[A]] =
    // `map.empty` keeps the same `Ordering` for subsequent offers.
    ref.modify(map => (expand(map), map.empty))

  /**
   * Takes up to the specified maximum number of elements from the queue.
   * A non-positive `n` takes nothing.
   */
  def takeUpTo(n: Int): USTM[Chunk[A]] =
    ref.modify { map =>
      val builder  = ChunkBuilder.make[A]()
      val iterator = map.iterator
      var updated  = map
      var i        = 0
      while (iterator.hasNext && i < n) {
        val (a, j) = iterator.next()
        var k      = 0
        while (i < n && k < j) {
          builder += a
          i += 1
          k += 1
        }
        // Drop the key if every copy was taken, otherwise keep the remaining count.
        if (k == j) updated -= a else updated += (a -> (j - k))
      }
      (builder.result(), updated)
    }

  /**
   * Takes a value from the queue, returning `None` if there is not a value in
   * the queue.
   */
  def takeOption: USTM[Option[A]] =
    new ZSTM((journal, _, _, _) => {
      val map = ref.unsafeGet(journal)
      map.headOption match {
        case None => TExit.Succeed(None)
        case Some((a, n)) =>
          ref.unsafeSet(journal, if (n == 1) map - a else map + (a -> (n - 1)))
          TExit.Succeed(Some(a))
      }
    })

  /**
   * Collects all values into a chunk, in ordering order, without removing them.
   */
  def toChunk: USTM[Chunk[A]] =
    ref.get.map(expand(_))

  /**
   * Collects all values into a list.
   */
  def toList: USTM[List[A]] =
    toChunk.map(_.toList)

  /**
   * Collects all values into a vector.
   */
  def toVector: USTM[Vector[A]] =
    toChunk.map(_.toVector)

  /**
   * Expands the value-to-count map into a chunk, in ordering order, repeating
   * each value as many times as its count.
   */
  private def expand(map: SortedMap[A, Int]): Chunk[A] = {
    val builder = ChunkBuilder.make[A]()
    map.foreach {
      case (a, n) =>
        var i = 0
        while (i < n) {
          builder += a
          i += 1
        }
    }
    builder.result()
  }
}

object TPriorityQueue {
Expand All @@ -155,13 +186,15 @@ object TPriorityQueue {
* Constructs a new empty `TPriorityQueue` with the specified `Ordering`.
*/
def empty[A](implicit ord: Ordering[A]): USTM[TPriorityQueue[A]] =
  // The implicit `ord` is captured by the `SortedMap`, which fixes the take order.
  TRef.make(SortedMap.empty[A, Int]).map(ref => new TPriorityQueue(ref))

/**
* Makes a new `TPriorityQueue` initialized with provided iterable.
*/
def fromIterable[A](data: Iterable[A])(implicit ord: Ordering[A]): USTM[TPriorityQueue[A]] =
  // Build the value-to-count map in one fold and allocate a single `TRef` for it,
  // rather than offering each element separately.
  TRef
    .make(data.foldLeft(SortedMap.empty[A, Int])((acc, a) => acc + (a -> acc.get(a).fold(1)(_ + 1))))
    .map(ref => new TPriorityQueue(ref))

/**
* Makes a new `TPriorityQueue` that is initialized with specified values.
Expand Down