diff --git a/core/src/main/java/io/grpc/internal/DelayedClientCall.java b/core/src/main/java/io/grpc/internal/DelayedClientCall.java index b568bb12c46..9d3c4abe0e2 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientCall.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientCall.java @@ -294,6 +294,7 @@ private void drainPendingCalls() { assert !passThrough; List toRun = new ArrayList<>(); DelayedListener delayedListener ; + drainOut: while (true) { synchronized (this) { if (pendingRunnables.isEmpty()) { @@ -311,8 +312,18 @@ private void drainPendingCalls() { } for (Runnable runnable : toRun) { // Must not call transport while lock is held to prevent deadlocks. - // TODO(ejona): exception handling - runnable.run(); + try { + runnable.run(); + } catch (Throwable t) { + Status status = Status.fromThrowable(t).withDescription("Failed to drain pending calls"); + realCall.cancel(status.getDescription(), status.getCause()); + synchronized (this) { + pendingRunnables = null; + passThrough = true; + delayedListener = this.delayedListener; + } + break drainOut; + } } toRun.clear(); } @@ -519,6 +530,7 @@ public void run() { void drainPendingCallbacks() { assert !passThrough; List toRun = new ArrayList<>(); + drainOut: while (true) { synchronized (this) { if (pendingCallbacks.isEmpty()) { @@ -535,8 +547,15 @@ void drainPendingCallbacks() { } for (Runnable runnable : toRun) { // Avoid calling listener while lock is held to prevent deadlocks. - // TODO(ejona): exception handling - runnable.run(); + try { + runnable.run(); + } catch (Throwable t) { + synchronized (this) { + pendingCallbacks = null; + passThrough = true; + } + throw t; + } } toRun.clear(); } diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index a2b1e963ac5..b44b27500d5 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -172,6 +172,7 @@ private void drainPendingCalls() { assert !passThrough; List toRun = new ArrayList<>(); DelayedStreamListener delayedListener = null; + drainOut: while (true) { synchronized (this) { if (pendingCalls.isEmpty()) { @@ -189,8 +190,18 @@ private void drainPendingCalls() { } for (Runnable runnable : toRun) { // Must not call transport while lock is held to prevent deadlocks. - // TODO(ejona): exception handling - runnable.run(); + try { + runnable.run(); + } catch (Throwable t) { + Status status = Status.fromThrowable(t).withDescription("Failed to drain pending calls"); + realStream.cancel(status); + synchronized (this) { + pendingCalls = null; + passThrough = true; + delayedListener = this.delayedListener; + } + break drainOut; + } } toRun.clear(); } @@ -525,6 +536,7 @@ public void run() { public void drainPendingCallbacks() { assert !passThrough; List toRun = new ArrayList<>(); + drainOut: while (true) { synchronized (this) { if (pendingCallbacks.isEmpty()) { @@ -541,8 +553,15 @@ public void drainPendingCallbacks() { } for (Runnable runnable : toRun) { // Avoid calling listener while lock is held to prevent deadlocks. - // TODO(ejona): exception handling - runnable.run(); + try { + runnable.run(); + } catch (Throwable t) { + synchronized (this) { + pendingCallbacks = null; + passThrough = true; + } + throw t; + } } toRun.clear(); } diff --git a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java index ff131d29975..f00be9a5de8 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java @@ -18,10 +18,13 @@ import static com.google.common.truth.Truth.assertThat; import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import com.google.common.util.concurrent.MoreExecutors; @@ -166,8 +169,8 @@ public void startThenSetCall() { @Test @SuppressWarnings("unchecked") public void cancelThenSetCall() { - DelayedClientCall delayedClientCall = new DelayedClientCall<>( - callExecutor, fakeClock.getScheduledExecutorService(), null); + DelayedClientCall delayedClientCall = + new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null); delayedClientCall.start(listener, new Metadata()); delayedClientCall.request(1); delayedClientCall.cancel("cancel", new StatusException(Status.CANCELLED)); @@ -182,8 +185,8 @@ public void cancelThenSetCall() { @Test @SuppressWarnings("unchecked") public void setCallThenCancel() { - DelayedClientCall delayedClientCall = new DelayedClientCall<>( - callExecutor, fakeClock.getScheduledExecutorService(), null); + DelayedClientCall delayedClientCall = + new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null); delayedClientCall.start(listener, new Metadata()); delayedClientCall.request(1); Runnable r = delayedClientCall.setCall(mockRealCall); @@ -229,6 +232,84 @@ public void delayedCallsRunUnderContext() throws Exception { assertThat(contextKey.get(readyContext.get())).isEqualTo(goldenValue); } + @Test + @SuppressWarnings("MissingFail") + public void drainPendingCallFails() { + DelayedClientCall delayedClientCall = + new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null); + delayedClientCall.start(listener, new Metadata()); + delayedClientCall.request(1); + + final RuntimeException error = new RuntimeException("fail"); + org.mockito.Mockito.doAnswer(new org.mockito.stubbing.Answer() { + @Override + public Void answer(org.mockito.invocation.InvocationOnMock invocation) { + throw error; + } + }).when(mockRealCall).request(1); + + Runnable runnable = delayedClientCall.setCall(mockRealCall); + assertThat(runnable).isNotNull(); + try { + runnable.run(); + } catch (RuntimeException e) { + assertThat(e).isSameInstanceAs(error); + } + + verify(mockRealCall).cancel(eq("Failed to drain pending calls"), same(error)); + } + + @Test + @SuppressWarnings("unchecked") + public void drainPendingCallbacksFails() { + DelayedClientCall delayedClientCall = + new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null); + delayedClientCall.start(listener, new Metadata()); + + final RuntimeException error = new RuntimeException("fail"); + org.mockito.Mockito.doAnswer(new org.mockito.stubbing.Answer() { + @Override + public Void answer(org.mockito.invocation.InvocationOnMock invocation) { + throw error; + } + }).when(listener).onReady(); + + final AtomicReference> listenerCaptor = new AtomicReference<>(); + org.mockito.Mockito.doAnswer(new org.mockito.stubbing.Answer() { + @Override + public Void answer(org.mockito.invocation.InvocationOnMock invocation) { + ClientCall.Listener delayedListener = invocation.getArgument(0); + listenerCaptor.set(delayedListener); + delayedListener.onReady(); + return null; + } + }).when(mockRealCall).start(any(ClientCall.Listener.class), any(Metadata.class)); + + Runnable runnable = delayedClientCall.setCall(mockRealCall); + assertThat(runnable).isNotNull(); + + try { + runnable.run(); + fail("Should have thrown"); + } catch (RuntimeException e) { + assertThat(e).isSameInstanceAs(error); + } + + ClientCall.Listener delayedListener = listenerCaptor.get(); + assertThat(delayedListener).isNotNull(); + + // Verify it transitioned to passThrough by showing it forwards. + try { + delayedListener.onReady(); + fail("Should have thrown"); + } catch (RuntimeException e) { + assertThat(e).isSameInstanceAs(error); + } + + // Verify it was called twice (once during drain, once just now) + verify(listener, times(2)).onReady(); + } + private void callMeMaybe(Runnable r) { if (r != null) { r.run(); diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index 12c32fcf126..01f49af2b36 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; @@ -472,6 +473,86 @@ public Void answer(InvocationOnMock in) { .matches("\\[test_op_delay=[0-9]+ns, remote_addr=127\\.0\\.0\\.1:443\\]"); } + @Test + @SuppressWarnings({"unchecked", "MissingFail"}) + public void drainPendingCallFails() { + stream.start(listener); + stream.request(1); + final RuntimeException error = new RuntimeException("fail"); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + throw error; + } + }).when(realStream).request(1); + + Runnable runnable = stream.setStream(realStream); + assertNotNull(runnable); + try { + runnable.run(); + } catch (RuntimeException e) { + assertThat(e).isSameInstanceAs(error); + } + + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(realStream).cancel(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.UNKNOWN); + assertThat(statusCaptor.getValue().getCause()).isSameInstanceAs(error); + + verify(realStream).start(listenerCaptor.capture()); + listenerCaptor.getValue().closed( + statusCaptor.getValue(), RpcProgress.PROCESSED, new Metadata()); + verify(listener).closed( + same(statusCaptor.getValue()), any(RpcProgress.class), any(Metadata.class)); + } + + @Test + @SuppressWarnings("unchecked") + public void drainPendingCallbacksFails() { + stream.start(listener); + final RuntimeException error = new RuntimeException("fail"); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + throw error; + } + }).when(listener).onReady(); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + ClientStreamListener delayedListener = invocation.getArgument(0); + delayedListener.onReady(); + return null; + } + }).when(realStream).start(any(ClientStreamListener.class)); + + Runnable runnable = stream.setStream(realStream); + assertNotNull(runnable); + + try { + runnable.run(); + fail("Should have thrown"); + } catch (RuntimeException e) { + assertThat(e).isSameInstanceAs(error); + } + + verify(realStream).start(listenerCaptor.capture()); + ClientStreamListener delayedListener = listenerCaptor.getValue(); + + // Verify it transitioned to passThrough. If it didn't, this might NPE or buffer. + // If it is passThrough, it will forward to the listener, which we know throws. + try { + delayedListener.onReady(); + fail("Should have thrown"); + } catch (RuntimeException e) { + assertThat(e).isSameInstanceAs(error); + } + + // Verify it was called twice (once during drain, once just now) + verify(listener, times(2)).onReady(); + } + private void callMeMaybe(Runnable r) { if (r != null) { r.run();