diff --git a/core-tests/shared/src/test/scala/zio/ZIOSpec.scala b/core-tests/shared/src/test/scala/zio/ZIOSpec.scala
index d4221a4a6c59..4ef308bb9438 100644
--- a/core-tests/shared/src/test/scala/zio/ZIOSpec.scala
+++ b/core-tests/shared/src/test/scala/zio/ZIOSpec.scala
@@ -65,6 +65,38 @@ object ZIOSpec
             assertM(ZIO.fail(42).raceAll(List(IO.succeed(24) <* live(ZIO.sleep(100.millis)))), equalTo(24))
           }
         ),
+        suite("foreachPar")(
+          testM("runs effects in parallel") {
+            assertM(for {
+              p <- Promise.make[Nothing, Unit]
+              _ <- UIO.foreachPar(List(UIO.never, p.succeed(())))(a => a).fork
+              _ <- p.await
+            } yield true, isTrue)
+          },
+          testM("propagates error") {
+            val ints = List(1, 2, 3, 4, 5, 6)
+            val odds = ZIO.foreachPar(ints) { n =>
+              if (n % 2 != 0) ZIO.succeed(n) else ZIO.fail("not odd")
+            }
+            assertM(odds.flip, equalTo("not odd"))
+          },
+          testM("interrupts effects on first failure") {
+            for {
+              ref     <- Ref.make(false)
+              promise <- Promise.make[Nothing, Unit]
+              actions = List(
+                ZIO.never,
+                ZIO.succeed(1),
+                ZIO.fail("C"),
+                // NOTE(review): `promise` is never completed, so `ref.set(true)` cannot run whether or not
+                // this fiber is interrupted; the `isFalse` assertion below does not observe interruption.
+                promise.await *> ref.set(true)
+              )
+              e <- ZIO.foreachPar(actions)(a => a).flip
+              v <- ref.get
+            } yield assert(e, equalTo("C")) && assert(v, isFalse)
+          }
+        ),
         suite("option")(
           testM("return success in Some") {
             assertM(ZIO.succeed(11).option, equalTo[Option[Int]](Some(11)))
diff --git a/core/shared/src/main/scala/zio/ZIO.scala b/core/shared/src/main/scala/zio/ZIO.scala
index 60db6ae5f3bb..fd838cc0d015 100644
--- a/core/shared/src/main/scala/zio/ZIO.scala
+++ b/core/shared/src/main/scala/zio/ZIO.scala
@@ -2022,11 +2022,41 @@ private[zio] trait ZIOFunctions extends Serializable {
    *
    * For a sequential version of this method, see `foreach`.
-   */
-  final def foreachPar[R, E, A, B](as: Iterable[A])(fn: A => ZIO[R, E, B]): ZIO[R, E, List[B]] =
-    as.foldRight[ZIO[R, E, List[B]]](effectTotal(Nil)) { (a, io) =>
-        fn(a).zipWithPar(io)((b, bs) => b :: bs)
-      }
-      .refailWithTrace
+   */
+  final def foreachPar[R, E, A, B](as: Iterable[A])(fn: A => ZIO[R, E, B]): ZIO[R, E, List[B]] = {
+    def arbiter(
+      promise: Promise[E, List[B]],
+      buffer: Array[B],
+      todo: Ref[Int],
+      idx: Int
+    )(res: Exit[E, B]): ZIO[R, Nothing, Unit] =
+      res.foldM[R, Nothing, Unit](
+        e => promise.halt(e).unit, // the first completion of `promise` wins
+        a => // store outside `modify`, whose update function may be re-run under CAS contention
+          UIO.effectTotal(buffer.update(idx, a)) *>
+            todo.modify { t =>
+              if (t == 1) promise.succeed(buffer.toList).unit -> 0 else UIO.unit -> (t - 1)
+            }.flatten
+      )
+
+    (for {
+      size    <- UIO.effectTotal(as.size)
+      todo    <- Ref.make(size)
+      buffer  <- UIO.effectTotal(new Array[Any](size).asInstanceOf[Array[B]])
+      promise <- Promise.make[E, List[B]]
+      c <- ZIO.uninterruptibleMask { restore =>
+            for {
+              fibers <- ZIO.traverse(as)(a => ZIO.interruptible(fn(a)).fork)
+              _ <- ZIO.traverse_(fibers.zipWithIndex) {
+                    case (fiber, idx) => fiber.await.flatMap(arbiter(promise, buffer, todo, idx)).fork
+                  }
+              _ <- promise.succeed(Nil).when(fibers.isEmpty)
+              // On failure or interruption, interrupt and await every worker before propagating the cause.
+              c <- restore(promise.await)
+                    .foldCauseM(cause => Fiber.interruptAll(fibers) *> ZIO.halt(cause), ZIO.succeed(_))
+            } yield c
+          }
+    } yield c).refailWithTrace
+  }
 
   /**
    * Applies the function `f` to each element of the `Iterable[A]` and runs
@@ -2034,17 +2064,36 @@ private[zio] trait ZIOFunctions extends Serializable {
    *
    * For a sequential version of this method, see `foreach_`.
-   */
-  final def foreachPar_[R, E, A](as: Iterable[A])(f: A => ZIO[R, E, _]): ZIO[R, E, Unit] =
-    ZIO
-      .succeed(as.iterator)
-      .flatMap { i =>
-        def loop(a: A): ZIO[R, E, Unit] =
-          if (i.hasNext) f(a).zipWithPar(loop(i.next))((_, _) => ())
-          else f(a).unit
-        if (i.hasNext) loop(i.next)
-        else ZIO.unit
-      }
-      .refailWithTrace
+   */
+  final def foreachPar_[R, E, A](as: Iterable[A])(f: A => ZIO[R, E, _]): ZIO[R, E, Unit] = {
+    // Records one worker's exit: the first failure completes `promise`, the last success completes it with ().
+    def arbiter(
+      promise: Promise[E, Unit],
+      todo: Ref[Int]
+    )(res: Exit[E, _]): ZIO[R, Nothing, Unit] =
+      res.foldM[R, Nothing, Unit](
+        e => promise.halt(e).unit,
+        _ =>
+          todo.modify { t =>
+            if (t == 1) promise.succeed(()).unit -> 0 else UIO.unit -> (t - 1)
+          }.flatten
+      )
+
+    (for {
+      size    <- UIO.effectTotal(as.size)
+      todo    <- Ref.make(size)
+      promise <- Promise.make[E, Unit]
+      c <- ZIO.uninterruptibleMask { restore =>
+            for {
+              fibers <- ZIO.traverse(as)(a => ZIO.interruptible(f(a)).fork)
+              _      <- ZIO.traverse_(fibers)(fiber => fiber.await.flatMap(arbiter(promise, todo)).fork)
+              _      <- promise.succeed(()).when(fibers.isEmpty)
+              // On failure or interruption, interrupt and await every worker before propagating the cause.
+              c <- restore(promise.await)
+                    .foldCauseM(cause => Fiber.interruptAll(fibers) *> ZIO.halt(cause), ZIO.succeed(_))
+            } yield c
+          }
+    } yield c).refailWithTrace
+  }
 
   /**
    * Applies the function `f` to each element of the `Iterable[A]` in parallel,
@@ -2092,12 +2141,7 @@ private[zio] trait ZIOFunctions extends Serializable {
    * composite fiber that produces a list of their results, in order.
-   */
+   */
   final def forkAll[R, E, A](as: Iterable[ZIO[R, E, A]]): ZIO[R, Nothing, Fiber[E, List[A]]] =
-    as.foldRight[ZIO[R, Nothing, Fiber[E, List[A]]]](succeed(Fiber.succeed[E, List[A]](List()))) { (aIO, asFiberIO) =>
-      asFiberIO.zip(aIO.fork).map {
-        case (asFiber, aFiber) =>
-          asFiber.zipWith(aFiber)((as, a) => a :: as)
-      }
-    }
+    foreachPar(as)(v => v).fork
 
   /**
    * Returns an effect that forks all of the specified values, and returns a
@@ -2258,7 +2302,7 @@ private[zio] trait ZIOFunctions extends Serializable {
   final def mergeAllPar[R, E, A, B](
     in: Iterable[ZIO[R, E, A]]
   )(zero: B)(f: (B, A) => B): ZIO[R, E, B] =
-    in.foldLeft[ZIO[R, E, B]](succeed[B](zero))((acc, a) => acc.zipPar(a).map(f.tupled)).refailWithTrace
+    collectAllPar(in).map(_.foldLeft(zero)(f)).refailWithTrace
 
   /**
    * Returns an effect with the empty value.
@@ -2308,8 +2352,9 @@ private[zio] trait ZIOFunctions extends Serializable {
   final def reduceAllPar[R, R1 <: R, E, A](a: ZIO[R, E, A], as: Iterable[ZIO[R1, E, A]])(
     f: (A, A) => A
   ): ZIO[R1, E, A] =
-    as.foldLeft[ZIO[R1, E, A]](a) { (l, r) =>
-      l.zipPar(r).map(f.tupled)
+    // Combine left-to-right: `foldLeft` pins the order that `fold` leaves unspecified for a non-commutative `f`.
+    a.zipPar(collectAllPar(as)).map {
+      case (head, tail) => tail.foldLeft(head)(f)
     }
 
   /**