Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions core/src/main/java/io/grpc/internal/DelayedClientCall.java
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ private void drainPendingCalls() {
assert !passThrough;
List<Runnable> toRun = new ArrayList<>();
DelayedListener<RespT> delayedListener ;
drainOut:
while (true) {
synchronized (this) {
if (pendingRunnables.isEmpty()) {
Expand All @@ -311,8 +312,18 @@ private void drainPendingCalls() {
}
for (Runnable runnable : toRun) {
// Must not call transport while lock is held to prevent deadlocks.
// TODO(ejona): exception handling
runnable.run();
try {
runnable.run();
} catch (Throwable t) {
Status status = Status.fromThrowable(t).withDescription("Failed to drain pending calls");
realCall.cancel(status.getDescription(), status.getCause());
synchronized (this) {
pendingRunnables = null;
passThrough = true;
delayedListener = this.delayedListener;
}
break drainOut;
}
}
toRun.clear();
}
Expand Down Expand Up @@ -519,6 +530,7 @@ public void run() {
void drainPendingCallbacks() {
assert !passThrough;
List<Runnable> toRun = new ArrayList<>();
drainOut:
while (true) {
synchronized (this) {
if (pendingCallbacks.isEmpty()) {
Expand All @@ -535,8 +547,15 @@ void drainPendingCallbacks() {
}
for (Runnable runnable : toRun) {
// Avoid calling listener while lock is held to prevent deadlocks.
// TODO(ejona): exception handling
runnable.run();
try {
runnable.run();
} catch (Throwable t) {
synchronized (this) {
pendingCallbacks = null;
passThrough = true;
}
throw t;
}
}
toRun.clear();
}
Expand Down
27 changes: 23 additions & 4 deletions core/src/main/java/io/grpc/internal/DelayedStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ private void drainPendingCalls() {
assert !passThrough;
List<Runnable> toRun = new ArrayList<>();
DelayedStreamListener delayedListener = null;
drainOut:
while (true) {
synchronized (this) {
if (pendingCalls.isEmpty()) {
Expand All @@ -189,8 +190,18 @@ private void drainPendingCalls() {
}
for (Runnable runnable : toRun) {
// Must not call transport while lock is held to prevent deadlocks.
// TODO(ejona): exception handling
runnable.run();
try {
runnable.run();
} catch (Throwable t) {
Status status = Status.fromThrowable(t).withDescription("Failed to drain pending calls");
realStream.cancel(status);
synchronized (this) {
pendingCalls = null;
passThrough = true;
delayedListener = this.delayedListener;
}
break drainOut;
}
}
toRun.clear();
}
Expand Down Expand Up @@ -525,6 +536,7 @@ public void run() {
public void drainPendingCallbacks() {
assert !passThrough;
List<Runnable> toRun = new ArrayList<>();
drainOut:
while (true) {
synchronized (this) {
if (pendingCallbacks.isEmpty()) {
Expand All @@ -541,8 +553,15 @@ public void drainPendingCallbacks() {
}
for (Runnable runnable : toRun) {
// Avoid calling listener while lock is held to prevent deadlocks.
// TODO(ejona): exception handling
runnable.run();
try {
runnable.run();
} catch (Throwable t) {
synchronized (this) {
pendingCallbacks = null;
passThrough = true;
}
throw t;
}
}
toRun.clear();
}
Expand Down
89 changes: 85 additions & 4 deletions core/src/test/java/io/grpc/internal/DelayedClientCallTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@

import static com.google.common.truth.Truth.assertThat;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.google.common.util.concurrent.MoreExecutors;
Expand Down Expand Up @@ -166,8 +169,8 @@ public void startThenSetCall() {
@Test
@SuppressWarnings("unchecked")
public void cancelThenSetCall() {
DelayedClientCall<String, Integer> delayedClientCall = new DelayedClientCall<>(
callExecutor, fakeClock.getScheduledExecutorService(), null);
DelayedClientCall<String, Integer> delayedClientCall =
new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null);
delayedClientCall.start(listener, new Metadata());
delayedClientCall.request(1);
delayedClientCall.cancel("cancel", new StatusException(Status.CANCELLED));
Expand All @@ -182,8 +185,8 @@ public void cancelThenSetCall() {
@Test
@SuppressWarnings("unchecked")
public void setCallThenCancel() {
DelayedClientCall<String, Integer> delayedClientCall = new DelayedClientCall<>(
callExecutor, fakeClock.getScheduledExecutorService(), null);
DelayedClientCall<String, Integer> delayedClientCall =
new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null);
delayedClientCall.start(listener, new Metadata());
delayedClientCall.request(1);
Runnable r = delayedClientCall.setCall(mockRealCall);
Expand Down Expand Up @@ -229,6 +232,84 @@ public void delayedCallsRunUnderContext() throws Exception {
assertThat(contextKey.get(readyContext.get())).isEqualTo(goldenValue);
}

@Test
@SuppressWarnings("MissingFail")
public void drainPendingCallFails() {
DelayedClientCall<String, Integer> delayedClientCall =
new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null);
delayedClientCall.start(listener, new Metadata());
delayedClientCall.request(1);

final RuntimeException error = new RuntimeException("fail");
org.mockito.Mockito.doAnswer(new org.mockito.stubbing.Answer<Void>() {
@Override
public Void answer(org.mockito.invocation.InvocationOnMock invocation) {
throw error;
}
}).when(mockRealCall).request(1);

Runnable runnable = delayedClientCall.setCall(mockRealCall);
assertThat(runnable).isNotNull();
try {
runnable.run();
} catch (RuntimeException e) {
assertThat(e).isSameInstanceAs(error);
}

verify(mockRealCall).cancel(eq("Failed to drain pending calls"), same(error));
}

@Test
@SuppressWarnings("unchecked")
public void drainPendingCallbacksFails() {
DelayedClientCall<String, Integer> delayedClientCall =
new DelayedClientCall<>(callExecutor, fakeClock.getScheduledExecutorService(), null);
delayedClientCall.start(listener, new Metadata());

final RuntimeException error = new RuntimeException("fail");
org.mockito.Mockito.doAnswer(new org.mockito.stubbing.Answer<Void>() {
@Override
public Void answer(org.mockito.invocation.InvocationOnMock invocation) {
throw error;
}
}).when(listener).onReady();

final AtomicReference<ClientCall.Listener<Integer>> listenerCaptor = new AtomicReference<>();
org.mockito.Mockito.doAnswer(new org.mockito.stubbing.Answer<Void>() {
@Override
public Void answer(org.mockito.invocation.InvocationOnMock invocation) {
ClientCall.Listener<Integer> delayedListener = invocation.getArgument(0);
listenerCaptor.set(delayedListener);
delayedListener.onReady();
return null;
}
}).when(mockRealCall).start(any(ClientCall.Listener.class), any(Metadata.class));

Runnable runnable = delayedClientCall.setCall(mockRealCall);
assertThat(runnable).isNotNull();

try {
runnable.run();
fail("Should have thrown");
} catch (RuntimeException e) {
assertThat(e).isSameInstanceAs(error);
}

ClientCall.Listener<Integer> delayedListener = listenerCaptor.get();
assertThat(delayedListener).isNotNull();

// Verify it transitioned to passThrough by showing it forwards.
try {
delayedListener.onReady();
fail("Should have thrown");
} catch (RuntimeException e) {
assertThat(e).isSameInstanceAs(error);
}

// Verify it was called twice (once during drain, once just now)
verify(listener, times(2)).onReady();
}

private void callMeMaybe(Runnable r) {
if (r != null) {
r.run();
Expand Down
81 changes: 81 additions & 0 deletions core/src/test/java/io/grpc/internal/DelayedStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
Expand Down Expand Up @@ -472,6 +473,86 @@ public Void answer(InvocationOnMock in) {
.matches("\\[test_op_delay=[0-9]+ns, remote_addr=127\\.0\\.0\\.1:443\\]");
}

@Test
@SuppressWarnings({"unchecked", "MissingFail"})
public void drainPendingCallFails() {
stream.start(listener);
stream.request(1);
final RuntimeException error = new RuntimeException("fail");
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
throw error;
}
}).when(realStream).request(1);

Runnable runnable = stream.setStream(realStream);
assertNotNull(runnable);
try {
runnable.run();
} catch (RuntimeException e) {
assertThat(e).isSameInstanceAs(error);
}

ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
verify(realStream).cancel(statusCaptor.capture());
assertThat(statusCaptor.getValue().getCode()).isEqualTo(Status.Code.UNKNOWN);
assertThat(statusCaptor.getValue().getCause()).isSameInstanceAs(error);

verify(realStream).start(listenerCaptor.capture());
listenerCaptor.getValue().closed(
statusCaptor.getValue(), RpcProgress.PROCESSED, new Metadata());
verify(listener).closed(
same(statusCaptor.getValue()), any(RpcProgress.class), any(Metadata.class));
}

@Test
@SuppressWarnings("unchecked")
public void drainPendingCallbacksFails() {
stream.start(listener);
final RuntimeException error = new RuntimeException("fail");
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
throw error;
}
}).when(listener).onReady();

doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
ClientStreamListener delayedListener = invocation.getArgument(0);
delayedListener.onReady();
return null;
}
}).when(realStream).start(any(ClientStreamListener.class));

Runnable runnable = stream.setStream(realStream);
assertNotNull(runnable);

try {
runnable.run();
fail("Should have thrown");
} catch (RuntimeException e) {
assertThat(e).isSameInstanceAs(error);
}

verify(realStream).start(listenerCaptor.capture());
ClientStreamListener delayedListener = listenerCaptor.getValue();

// Verify it transitioned to passThrough. If it didn't, this might NPE or buffer.
// If it is passThrough, it will forward to the listener, which we know throws.
try {
delayedListener.onReady();
fail("Should have thrown");
} catch (RuntimeException e) {
assertThat(e).isSameInstanceAs(error);
}

// Verify it was called twice (once during drain, once just now)
verify(listener, times(2)).onReady();
}

private void callMeMaybe(Runnable r) {
if (r != null) {
r.run();
Expand Down
Loading