diff --git a/test-tests/shared/src/test/scala/zio/test/mock/ParallelMockSpec.scala b/test-tests/shared/src/test/scala/zio/test/mock/ParallelMockSpec.scala
new file mode 100644
index 000000000000..a72b1c10ee2f
--- /dev/null
+++ b/test-tests/shared/src/test/scala/zio/test/mock/ParallelMockSpec.scala
@@ -0,0 +1,33 @@
+package zio.test.mock
+
+import zio.ZIO
+import zio.test._
+import zio.test.mock.module.{ImpureModule, ImpureModuleMock}
+
+object ParallelMockSpec extends ZIOBaseSpec {
+
+  import Assertion._
+  import Expectation._
+
+  def spec: Spec[Any, TestFailure[Any], TestSuccess] =
+    suite("ParallelMockSpec")(
+      testM("Count calls for the same expectation") {
+        val mock = ImpureModuleMock.SingleParam(equalTo(1), value("r1")).repeats(100 to 100)
+        val app  = ZIO.collectAllPar(Vector.fill(100)(ImpureModule.singleParam(1))).provideLayer(mock)
+        assertM(app)(hasSize[String](equalTo(100)) && hasSameElementsDistinct[String](Seq("r1")))
+      },
+      testM("Collect calls for all expectations") {
+        val params = 1 to 100
+        val mock =
+          params
+            .map(i => ImpureModuleMock.SingleParam(equalTo(i), value(s"r$i")))
+            .reduce(_ && _)
+
+        val app = ZIO.collectAllPar(params.map(i => ImpureModule.singleParam(i))).provideLayer(mock)
+
+        val expected = params.map(i => s"r$i")
+
+        assertM(app)(hasSameElements(expected))
+      }
+    )
+}
diff --git a/test/shared/src/main/scala/zio/test/mock/internal/MockState.scala b/test/shared/src/main/scala/zio/test/mock/internal/MockState.scala
index b39061a9b667..642271a8bc8b 100644
--- a/test/shared/src/main/scala/zio/test/mock/internal/MockState.scala
+++ b/test/shared/src/main/scala/zio/test/mock/internal/MockState.scala
@@ -24,18 +24,16 @@ import zio.{Has, Ref, UIO, ZIO}
  */
 private[mock] final case class MockState[R <: Has[_]](
   expectationRef: Ref[Expectation[R]],
-  callsCountRef: Ref[Int],
-  failedMatchesRef: Ref[List[InvalidCall]]
+  callsCountRef: Ref[Int]
 )
 
 private[mock] object MockState {
 
   def make[R <: Has[_]](trunk: Expectation[R]): UIO[MockState[R]] =
     for {
-      expectationRef   <- Ref.make[Expectation[R]](trunk)
-      callsCountRef    <- Ref.make[Int](0)
-      failedMatchesRef <- Ref.make[List[InvalidCall]](List.empty)
-    } yield MockState[R](expectationRef, callsCountRef, failedMatchesRef)
+      expectationRef <- Ref.make[Expectation[R]](trunk)
+      callsCountRef  <- Ref.make[Int](0)
+    } yield MockState[R](expectationRef, callsCountRef)
 
   def checkUnmetExpectations[R <: Has[_]](state: MockState[R]): ZIO[Any, Nothing, Any] =
     state.expectationRef.get
diff --git a/test/shared/src/main/scala/zio/test/mock/internal/ProxyFactory.scala b/test/shared/src/main/scala/zio/test/mock/internal/ProxyFactory.scala
--- a/test/shared/src/main/scala/zio/test/mock/internal/ProxyFactory.scala
+++ b/test/shared/src/main/scala/zio/test/mock/internal/ProxyFactory.scala
@@ -18,7 +18,7 @@ package zio.test.mock.internal
 
 import zio.test.Assertion
 import zio.test.mock.{Capability, Expectation, Proxy}
-import zio.{Has, IO, Tag, UIO, ULayer, ZIO, ZLayer}
+import zio.{Has, IO, Tag, ULayer, ZIO, ZLayer}
 
 import scala.util.Try
 
@@ -36,10 +36,22 @@ object ProxyFactory {
   def mockProxy[R <: Has[_]: Tag](state: MockState[R]): ULayer[Has[Proxy]] =
     ZLayer.succeed(new Proxy {
       def invoke[RIn <: Has[_], ROut, I, E, A](invoked: Capability[RIn, I, E, A], args: I): ZIO[ROut, E, A] = {
-        def findMatching(scopes: List[Scope[R]]): UIO[Matched[R, E, A]] = {
+        /**
+         * Outcome of matching one invocation against the expectation tree. `findMatching`
+         * returns it as a plain value (rather than an effect) so the search can run inside
+         * `Ref#modify`, atomically with the update of the root expectation.
+         */
+        sealed trait MatchResult
+        object MatchResult {
+          case object UnexpectedCall extends MatchResult
+          final case class Success(value: Matched[R, E, A]) extends MatchResult
+          final case class Failure(failures: List[InvalidCall]) extends MatchResult
+        }
+
+        def findMatching(scopes: List[Scope[R]], failedMatches: List[InvalidCall]): MatchResult = {
           debug(s"::: invoked $invoked\n${prettify(scopes)}")
           scopes match {
-            case Nil => ZIO.die(UnexpectedCallException(invoked, args))
+            case Nil => MatchResult.UnexpectedCall
             case Scope(expectation, id, update0) :: nextScopes =>
               val update: Expectation[R] => Expectation[R] = updated => {
                 debug(s"::: updated state to: ${updated.state}")
@@ -49,7 +61,7 @@ object ProxyFactory {
               expectation match {
                 case anyExpectation if anyExpectation.state == Saturated =>
                   debug("::: skipping saturated expectation")
-                  findMatching(nextScopes)
+                  findMatching(nextScopes, failedMatches)
 
                 case call @ Call(capability, assertion, returns, _, invocations) if invoked isEqual capability =>
                   debug(s"::: matched call $capability")
@@ -63,12 +75,13 @@ object ProxyFactory {
                         invocations = id :: invocations
                       )
 
-                      UIO.succeedNow(Matched[R, E, A](update(updated), result))
+                      MatchResult.Success(Matched[R, E, A](update(updated), result))
 
                     case false =>
                       handleLeafFailure(
                         InvalidArguments(invoked, args, assertion.asInstanceOf[Assertion[Any]]),
-                        nextScopes
+                        nextScopes,
+                        failedMatches
                       )
                   }
 
@@ -78,7 +91,7 @@ object ProxyFactory {
                     if (invoked.id == capability.id) InvalidPolyType(invoked, args, capability, assertion)
                     else InvalidCapability(invoked, capability, assertion)
 
-                  handleLeafFailure(invalidCall, nextScopes)
+                  handleLeafFailure(invalidCall, nextScopes, failedMatches)
 
                 case self @ Chain(children, _, invocations, _) =>
                   val scope = children.zipWithIndex.collectFirst {
@@ -100,7 +113,7 @@ object ProxyFactory {
                       )
                   }
 
-                  findMatching(scope.get :: nextScopes)
+                  findMatching(scope.get :: nextScopes, failedMatches)
 
                 case self @ And(children, _, invocations, _) =>
                   val scopes = children.zipWithIndex.collect {
@@ -122,7 +135,7 @@ object ProxyFactory {
                       )
                   }
 
-                  findMatching(scopes ++ nextScopes)
+                  findMatching(scopes ++ nextScopes, failedMatches)
 
                 case self @ Or(children, _, invocations, _) =>
                   children.zipWithIndex.find(_._1.state == PartiallySatisfied) match {
@@ -143,7 +156,7 @@ object ProxyFactory {
                         }
                       )
 
-                      findMatching(scope :: nextScopes)
+                      findMatching(scope :: nextScopes, failedMatches)
                     case None =>
                       val scopes = children.zipWithIndex.collect { case (child, index) =>
                         Scope[R](
@@ -163,7 +176,7 @@ object ProxyFactory {
                         )
                       }
 
-                      findMatching(scopes ++ nextScopes)
+                      findMatching(scopes ++ nextScopes, failedMatches)
                   }
 
                 case self @ Repeated(expectation, range, state, invocations, started, completed) =>
@@ -214,7 +227,7 @@ object ProxyFactory {
                     }
                   )
 
-                  findMatching(scope :: nextScopes)
+                  findMatching(scope :: nextScopes, failedMatches)
               }
           }
         }
@@ -232,13 +245,15 @@ object ProxyFactory {
         def maximumState(children: List[Expectation[R]]): ExpectationState =
           children.map(_.state).max
 
-        def handleLeafFailure(failure: => InvalidCall, nextScopes: List[Scope[R]]): UIO[Matched[R, E, A]] =
-          state.failedMatchesRef
-            .updateAndGet(failure :: _)
-            .flatMap { failures =>
-              if (nextScopes.isEmpty) ZIO.die(InvalidCallException(failures))
-              else findMatching(nextScopes)
-            }
+        def handleLeafFailure(
+          failure: InvalidCall,
+          nextScopes: List[Scope[R]],
+          failedMatches: List[InvalidCall]
+        ): MatchResult = {
+          val nextFailed = failure :: failedMatches
+          if (nextScopes.isEmpty) MatchResult.Failure(nextFailed)
+          else findMatching(nextScopes, nextFailed)
+        }
 
         def resetTree(expectation: Expectation[R]): Expectation[R] =
           expectation match {
@@ -268,14 +283,26 @@ object ProxyFactory {
           }
 
         for {
-          id      <- state.callsCountRef.updateAndGet(_ + 1)
-          _       <- state.failedMatchesRef.set(List.empty)
-          root    <- state.expectationRef.get
-          scope    = Scope[R](root, id, identity)
-          matched <- findMatching(scope :: Nil)
-          _        = debug(s"::: setting root to\n${prettify(matched.expectation)}")
-          _       <- state.expectationRef.set(matched.expectation)
-          output  <- matched.result
+          id <- state.callsCountRef.updateAndGet(_ + 1)
+          // Match and update the root in one atomic step so concurrent invocations cannot lose
+          // each other's updates. `modify` may re-run its function under contention, so it only
+          // computes the new root (debug output may repeat); failures are raised afterwards.
+          matchResult <-
+            state.expectationRef.modify { root =>
+              val res = findMatching(Scope[R](root, id, identity) :: Nil, Nil)
+              res match {
+                case MatchResult.Success(matched)                        => res -> matched.expectation
+                case MatchResult.UnexpectedCall | MatchResult.Failure(_) => res -> root
+              }
+            }
+          matched <-
+            matchResult match {
+              case MatchResult.Success(matched)  => ZIO.succeed(matched)
+              case MatchResult.UnexpectedCall    => ZIO.die(UnexpectedCallException(invoked, args))
+              case MatchResult.Failure(failures) => ZIO.die(InvalidCallException(failures))
+            }
+          _       = debug(s"::: setting root to\n${prettify(matched.expectation)}")
+          output <- matched.result
         } yield output
       }
     })