diff --git a/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala b/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala
index c6b20fe61f89..4e996432f742 100644
--- a/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala
+++ b/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala
@@ -2291,21 +2291,56 @@ object ZStreamSpec extends ZIOBaseSpec {
           } yield assert(result)(equalTo(List(1, 2, 4, 5)))
         }
       ),
-      testM("toInputStream") {
-        val stream = ZStream(-3, -2, -1, 0, 1, 2, 3).map(_.toByte)
-        for {
-          streamResult <- stream.runCollect
-          inputStreamResult <- stream.toInputStream.use { inputStream =>
-                                ZIO.succeed(
-                                  Iterator
-                                    .continually(inputStream.read)
-                                    .takeWhile(_ != -1)
-                                    .map(_.toByte)
-                                    .toList
-                                )
-                              }
-        } yield assert(streamResult)(equalTo(inputStreamResult))
-      },
+      suite("toInputStream")(
+        testM("read one-by-one") {
+          checkM(tinyListOf(Gen.chunkOf(Gen.anyByte))) { chunks =>
+            val content = chunks.flatMap(_.toList) // expected bytes, in stream order
+            ZStream.fromChunks(chunks: _*).toInputStream.use[Any, Throwable, TestResult] { is =>
+              ZIO.succeedNow( // drain with single-byte reads until read() reports -1
+                assert(Iterator.continually(is.read()).takeWhile(_ != -1).map(_.toByte).toList)(equalTo(content))
+              )
+            }
+          }
+        },
+        testM("read in batches") {
+          checkM(tinyListOf(Gen.chunkOf(Gen.anyByte))) {
+            chunks =>
+              val content = chunks.flatMap(_.toList) // expected bytes, in stream order
+              ZStream.fromChunks(chunks: _*).toInputStream.use[Any, Throwable, TestResult] { is =>
+                val batches: List[(Array[Byte], Int)] = Iterator.continually {
+                  val buf = new Array[Byte](10) // fresh buffer per read, larger than the requested length
+                  val res = is.read(buf, 0, 4) // ask for at most 4 bytes per call
+                  (buf, res)
+                }.takeWhile(_._2 != -1).toList
+                val combined = batches.flatMap { case (buf, size) => buf.take(size) } // keep only the bytes each read reported
+                ZIO.succeedNow(assert(combined)(equalTo(content)))
+              }
+          }
+        },
+        testM("`available` returns the size of chunk's leftover") {
+          ZStream
+            .fromIterable((1 to 10).map(_.toByte))
+            .chunkN(3) // presumably yields chunks of 3, 3, 3 and 1 bytes; confirm against chunkN
+            .toInputStream
+            .use[Any, Throwable, TestResult](is =>
+              ZIO.effect {
+                val cold = is.available() // sampled before any read
+                is.read()
+                val at1 = is.available() // sampled after 1 byte has been read
+                is.read(new Array[Byte](2))
+                val at3 = is.available() // sampled after 3 bytes have been read
+                is.read()
+                val at4 = is.available() // sampled after 4 bytes have been read
+                List(
+                  assert(cold)(equalTo(0)),
+                  assert(at1)(equalTo(2)),
+                  assert(at3)(equalTo(0)),
+                  assert(at4)(equalTo(2))
+                ).reduce(_ && _)
+              }
+            )
+        }
+      ),
       testM("toIterator") {
         (for {
           counter <- Ref.make(0).toManaged_ //Increment and get the value
@@ -2773,5 +2808,4 @@ object ZStreamSpec extends ZIOBaseSpec {
       }
       testResult <- assertion(chunkCoordination)
     } yield testResult
-
 }
diff --git a/streams/shared/src/main/scala/zio/stream/ZStream.scala b/streams/shared/src/main/scala/zio/stream/ZStream.scala
index b8f629424695..7ad925659b01 100644
--- a/streams/shared/src/main/scala/zio/stream/ZStream.scala
+++ b/streams/shared/src/main/scala/zio/stream/ZStream.scala
@@ -8,6 +8,7 @@ import zio.duration.Duration
 import zio.internal.UniqueKey
 import zio.stm.TQueue
 import zio.stream.internal.Utils.zipChunks
+import zio.stream.internal.ZInputStream
 
 abstract class ZStream[-R, +E, +O](val process: ZManaged[R, Nothing, ZIO[R, Option[E], Chunk[O]]]) { self =>
 
@@ -2650,48 +2651,11 @@ abstract class ZStream[-R, +E, +O](val process: ZManaged[R, Nothing, ZIO[R, Opti
    * Converts this stream of bytes into a `java.io.InputStream` wrapped in a [[ZManaged]].
    * The returned input stream will only be valid within the scope of the ZManaged.
    */
-  def toInputStream(implicit ev0: E <:< Throwable, ev1: O <:< Byte): ZManaged[R, E, java.io.InputStream] = {
-    val (_, _) = (ev0, ev1)
-
+  def toInputStream(implicit ev0: E <:< Throwable, ev1: O <:< Byte): ZManaged[R, E, java.io.InputStream] =
     for {
       runtime <- ZIO.runtime[R].toManaged_
-      pull <- process
-      javaStream = new java.io.InputStream {
-        val capturedPull = pull.asInstanceOf[ZIO[R, Option[Throwable], Chunk[Byte]]]
-        var done = false
-        var nextIndex: Int = -1
-        var currChunk: Chunk[Byte] = null
-
-        override def read(): Int =
-          if (done) -1
-          else {
-            if ((currChunk ne null) && nextIndex < currChunk.size) {
-              val result = currChunk(nextIndex)
-              nextIndex += 1
-              result & 0xFF
-            } else {
-              runtime.unsafeRunSync(capturedPull) match {
-                case Exit.Failure(cause) =>
-                  cause.failureOrCause match {
-                    case Left(None) =>
-                      done = true
-                      -1
-                    case Left(Some(throwable)) =>
-                      throw throwable
-                    case Right(otherCause) =>
-                      throw FiberFailure(otherCause)
-                  }
-
-                case Exit.Success(chunk) =>
-                  currChunk = chunk
-                  nextIndex = 0
-                  read()
-              }
-            }
-          }
-      }
-    } yield javaStream
-  }
+      pull <- process.asInstanceOf[ZManaged[R, Nothing, ZIO[R, Option[Throwable], Chunk[Byte]]]] // cast relies on the ev0/ev1 evidence (E <:< Throwable, O <:< Byte)
+    } yield ZInputStream.fromPull(runtime, pull) // ZInputStream runs `pull` synchronously on the captured runtime
 
   /**
    * Converts this stream into a `scala.collection.Iterator` wrapped in a [[ZManaged]].
diff --git a/streams/shared/src/main/scala/zio/stream/internal/ZInputStream.scala b/streams/shared/src/main/scala/zio/stream/internal/ZInputStream.scala
new file mode 100644
index 000000000000..42509f095a61
--- /dev/null
+++ b/streams/shared/src/main/scala/zio/stream/internal/ZInputStream.scala
@@ -0,0 +1,169 @@
+package zio.stream.internal
+
+import scala.annotation.tailrec
+
+import zio.Runtime
+import zio.{ Chunk, Exit, FiberFailure, ZIO }
+
+/**
+ * A `java.io.InputStream` that serves its bytes from an iterator of chunks.
+ *
+ * A new chunk is requested from `chunks` only once every byte of the current
+ * one has been read. The read position lives in plain, unsynchronized `var`s.
+ */
+private[zio] class ZInputStream(chunks: Iterator[Chunk[Byte]]) extends java.io.InputStream {
+  private var current: Chunk[Byte] = Chunk.empty
+  private var currentPos: Int      = 0
+  private var currentChunkLen: Int = 0
+  private var done: Boolean        = false
+
+  /** Number of bytes of the current chunk that have not been read yet. */
+  @inline private def availableInCurrentChunk: Int = currentChunkLen - currentPos
+
+  /** Returns the next byte of the current chunk; the caller must ensure one is available. */
+  @inline
+  private def readOne(): Byte = {
+    val res = current(currentPos)
+    currentPos += 1
+    res
+  }
+
+  /** Moves on to the next chunk, or marks the stream as finished when there is none. */
+  private def loadNext(): Unit =
+    if (chunks.hasNext) {
+      current = chunks.next()
+      currentChunkLen = current.length
+      currentPos = 0
+    } else {
+      done = true
+    }
+
+  /**
+   * Reads one byte, skipping over empty chunks.
+   *
+   * @return the byte as a value in `0..255`, or `-1` once `chunks` is exhausted
+   */
+  override def read(): Int = {
+    @tailrec
+    def go(): Int =
+      if (done) {
+        -1
+      } else if (availableInCurrentChunk > 0) {
+        readOne() & 0xFF
+      } else {
+        // current chunk is used up (or empty): fetch the next one and retry
+        loadNext()
+        go()
+      }
+
+    go()
+  }
+
+  /**
+   * Copies up to `len` bytes into `bytes`, starting at index `off`. Chunks
+   * keep being loaded until `len` bytes were copied or `chunks` is exhausted.
+   *
+   * @return the number of bytes copied, `0` if `len` is `0`, or `-1` if the
+   *         stream had already ended or no byte could be copied
+   * @throws java.lang.NullPointerException if `bytes` is `null`
+   * @throws java.lang.IndexOutOfBoundsException if `off` or `len` is negative,
+   *         or `len` exceeds `bytes.length - off`
+   */
+  override def read(bytes: Array[Byte], off: Int, len: Int): Int = {
+    // Validate before touching the stream, as the InputStream contract
+    // requires: otherwise a negative `len` would be returned as the byte
+    // count, and an undersized array would fail only after bytes had been
+    // consumed from the current chunk.
+    if (bytes eq null) throw new NullPointerException("bytes")
+    if (off < 0 || len < 0 || len > bytes.length - off)
+      throw new IndexOutOfBoundsException(s"off=$off, len=$len, bytes.length=${bytes.length}")
+
+    if (done) {
+      -1
+    } else if (len != 0) {
+      val written = doRead(bytes, off, len, 0)
+      if (written == 0) -1 else written
+    } else {
+      //cater to InputStream specification
+      0
+    }
+  }
+
+  /** Fills `bytes` from successive chunks; returns `written` plus the number of bytes copied. */
+  @tailrec
+  private def doRead(bytes: Array[Byte], off: Int, len: Int, written: Int): Int =
+    if (len <= availableInCurrentChunk) {
+      readFromCurrentChunk(bytes, off, len)
+      written + len
+    } else {
+      val av = availableInCurrentChunk
+      readFromCurrentChunk(bytes, off, av)
+      loadNext()
+      if (done) {
+        written + av
+      } else {
+        doRead(bytes, off + av, len - av, written + av)
+      }
+    }
+
+  /** Copies `len` bytes of the current chunk into `bytes`, starting at `off`. */
+  private def readFromCurrentChunk(bytes: Array[Byte], off: Int, len: Int): Unit = {
+    var i: Int = 0
+    while (i < len) {
+      bytes.update(off + i, readOne())
+      i += 1
+    }
+  }
+
+  /** Unread bytes of the current chunk only; chunks not yet loaded are not counted. */
+  override def available(): Int = availableInCurrentChunk
+}
+
+private[zio] object ZInputStream {
+
+  /**
+   * Creates a [[ZInputStream]] whose chunks come from running `pull`
+   * synchronously on `runtime`.
+   *
+   * `pull` is run only when the stream asks for another chunk, so building
+   * the stream executes nothing. A `None` failure ends the stream, a
+   * `Some(e)` failure rethrows `e`, and any other cause is thrown as a
+   * `FiberFailure`.
+   */
+  def fromPull[R](runtime: Runtime[R], pull: ZIO[R, Option[Throwable], Chunk[Byte]]): ZInputStream = {
+    // A hand-written iterator instead of recursive `Iterator.single(chunk) ++ ...`:
+    // it stays lazy for the first chunk too and does not nest one iterator per chunk.
+    val chunks: Iterator[Chunk[Byte]] = new Iterator[Chunk[Byte]] {
+      // chunk fetched by `hasNext` that `next()` has not handed out yet
+      private var buffered: Option[Chunk[Byte]] = None
+      private var finished: Boolean             = false
+
+      override def hasNext: Boolean = {
+        if (buffered.isEmpty && !finished) {
+          runtime.unsafeRunSync(pull) match {
+            case Exit.Success(chunk) => buffered = Some(chunk)
+            case Exit.Failure(cause) =>
+              cause.failureOrCause match {
+                case Left(None)    => finished = true
+                case Left(Some(e)) => throw e
+                case Right(c)      => throw FiberFailure(c)
+              }
+          }
+        }
+        buffered.nonEmpty
+      }
+
+      override def next(): Chunk[Byte] =
+        if (hasNext) {
+          // `hasNext` returning true guarantees that `buffered` is defined
+          val chunk = buffered.getOrElse(Chunk.empty)
+          buffered = None
+          chunk
+        } else {
+          throw new NoSuchElementException("next on empty iterator")
+        }
+    }
+
+    new ZInputStream(chunks)
+  }
+}