diff --git a/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java b/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java index 2799dfccb61..56ba9fd3214 100644 --- a/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java +++ b/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java @@ -99,6 +99,8 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { // 4c. active, prev!=null[handlerRemoved]: channel will be closed out-of-band by buffered write. // 4d. active, prev!=null[connect]: impossible, channel can't be active after a failed connect. if (ctx.channel().isActive() && previousFailure == null) { + ctx.fireExceptionCaught(cause); + final class LogOnFailure implements ChannelFutureListener { @Override public void operationComplete(ChannelFuture future) { diff --git a/netty/src/test/java/io/grpc/netty/WriteBufferingAndExceptionHandlerTest.java b/netty/src/test/java/io/grpc/netty/WriteBufferingAndExceptionHandlerTest.java index b99a9386fcf..9c2022cfc14 100644 --- a/netty/src/test/java/io/grpc/netty/WriteBufferingAndExceptionHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/WriteBufferingAndExceptionHandlerTest.java @@ -31,8 +31,11 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultEventLoop; @@ -41,6 +44,7 @@ import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalServerChannel; import java.net.ConnectException; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -381,4 +385,61 @@ public void uncaughtReadFails() throws Exception { assertThat(status.getDescription()).contains("channelRead() missed"); } } + + @Test + public void handshakeFailure_isPropagatedOnce() throws Exception { + AtomicInteger exceptionCount = new AtomicInteger(); + CountDownLatch latch = new CountDownLatch(1); + + ChannelHandler observer = + new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + exceptionCount.incrementAndGet(); + latch.countDown(); + } + }; + + WriteBufferingAndExceptionHandler handler = + new WriteBufferingAndExceptionHandler(new ChannelHandlerAdapter() {}); + + LocalAddress addr = new LocalAddress("local"); + + ChannelFuture cf = + new Bootstrap() + .channel(LocalChannel.class) + .group(group) + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(handler); + ch.pipeline().addLast(observer); + } + }) + .register(); + + chan = cf.channel(); + cf.sync(); + + ChannelFuture sf = + new ServerBootstrap() + .group(group) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInboundHandlerAdapter() {}) + .bind(addr); + server = sf.channel(); + sf.sync(); + + chan.connect(addr).sync(); + + RuntimeException handshakeFailure = + Status.UNAVAILABLE.withDescription("handshake failed").asRuntimeException(); + + chan.pipeline().fireExceptionCaught(handshakeFailure); + chan.pipeline().fireExceptionCaught(new RuntimeException("Second")); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertThat(exceptionCount.get()).isEqualTo(1); + } }