diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index 9f6faad245..9fd98ef5c4 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -1477,7 +1477,11 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, } def endSupply(result: Either[Throwable, Unit]): F2[Unit] = - buffer.update(_.copy(endOfSupply = Some(result))) *> supply.releaseN(Int.MaxValue) + buffer.update(_.copy(endOfSupply = Some(result))) *> supply.releaseN( + // enough supply for 2 iterations of the race loop in case of upstream + // interruption: so that downstream can terminate immediately + outputLong * 2 + ) def endDemand(result: Either[Throwable, Unit]): F2[Unit] = buffer.update(_.copy(endOfDemand = Some(result))) *> demand.releaseN(Int.MaxValue) diff --git a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala index fb49f21339..3bf473d795 100644 --- a/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala +++ b/core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala @@ -23,7 +23,7 @@ package fs2 import cats.effect.kernel.Deferred import cats.effect.kernel.Ref -import cats.effect.std.{Semaphore, Queue} +import cats.effect.std.{Queue, Semaphore} import cats.effect.testkit.TestControl import cats.effect.{IO, SyncIO} import cats.syntax.all._ @@ -34,6 +34,7 @@ import org.scalacheck.Prop.forAll import scala.concurrent.duration._ import scala.concurrent.TimeoutException +import scala.util.control.NoStackTrace class StreamCombinatorsSuite extends Fs2Suite { override def munitIOTimeout = 1.minute @@ -834,6 +835,74 @@ class StreamCombinatorsSuite extends Fs2Suite { ) .assertEquals(0.millis) } + + test("upstream failures are propagated downstream") { + TestControl.executeEmbed { + case object SevenNotAllowed extends NoStackTrace + + val source = Stream + .iterate(0)(_ + 1) + .covary[IO] + .evalTap(n => IO.raiseError(SevenNotAllowed).whenA(n == 7)) + + val downstream = source.groupWithin(100, 2.seconds).map(_.toList) + + val expected = List(List(1, 2, 3, 4, 5, 6)) + + downstream.assertEmits(expected).intercept[SevenNotAllowed.type] + } + } + + test( + "upstream interruption causes immediate downstream termination with all elements being emitted" + ) { + + val sourceTimeout = 5.5.seconds + val downstreamTimeout = sourceTimeout + 2.seconds + + TestControl + .executeEmbed { + val source: Stream[IO, Int] = + Stream + .iterate(0)(_ + 1) + .covary[IO] + .meteredStartImmediately(1.second) + .interruptAfter(sourceTimeout) + + // large chunkSize and timeout (no emissions expected in the window + // specified, unless source ends, due to interruption or + // natural termination (i.e runs out of elements) + val downstream: Stream[IO, Chunk[Int]] = + source.groupWithin(Int.MaxValue, 1.day) + + downstream.compile.lastOrError + .timeout(downstreamTimeout) + .map(_.toList) + .timed + } + .assertEquals( + // downstream ended immediately (i.e timeLapsed = sourceTimeout) + // emitting whatever was accumulated at the time of interruption + (sourceTimeout, List(0, 1, 2, 3, 4, 5)) + ) + } + + test("stress test: all elements are processed") { + val rangeLength = 10000 + + Stream + .eval(Ref.of[IO, Int](0)) + .flatMap { counter => + Stream + .range(0, rangeLength) + .covary[IO] + .groupWithin(4096, 100.micros) + .evalTap(ch => counter.update(_ + ch.size)) *> Stream.eval(counter.get) + } + .compile + .lastOrError + .assertEquals(rangeLength) + } } property("head")(forAll((s: Stream[Pure, Int]) => assertEquals(s.head.toList, s.toList.take(1))))