diff --git a/rlib-network/src/main/java/javasabr/rlib/network/exception/ConnectionClosedException.java b/rlib-network/src/main/java/javasabr/rlib/network/exception/ConnectionClosedException.java new file mode 100644 index 00000000..c5e90114 --- /dev/null +++ b/rlib-network/src/main/java/javasabr/rlib/network/exception/ConnectionClosedException.java @@ -0,0 +1,12 @@ +package javasabr.rlib.network.exception; + +public class ConnectionClosedException extends NetworkException { + + public ConnectionClosedException(String remoteAddress) { + super("Connection closed: %s".formatted(remoteAddress)); + } + + public ConnectionClosedException(String remoteAddress, Throwable cause) { + super("Connection closed: %s".formatted(remoteAddress), cause); + } +} diff --git a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java index 7e8e04c6..712614dd 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java @@ -4,18 +4,23 @@ import java.nio.channels.AsynchronousChannel; import java.nio.channels.AsynchronousSocketChannel; +import java.util.Collection; import java.util.Deque; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.StampedLock; import java.util.function.BiConsumer; +import javasabr.rlib.collections.array.Array; import javasabr.rlib.collections.array.ArrayFactory; +import javasabr.rlib.collections.array.LockableArray; import javasabr.rlib.collections.array.MutableArray; import javasabr.rlib.collections.deque.DequeFactory; +import javasabr.rlib.collections.operation.LockableOperations; import javasabr.rlib.network.BufferAllocator; import javasabr.rlib.network.Connection; import javasabr.rlib.network.Network; import javasabr.rlib.network.UnsafeConnection; +import javasabr.rlib.network.exception.ConnectionClosedException; import javasabr.rlib.network.packet.NetworkPacketReader; import javasabr.rlib.network.packet.NetworkPacketWriter; import javasabr.rlib.network.packet.ReadableNetworkPacket; @@ -64,6 +69,8 @@ public WritablePacketWithFeedback(CompletableFuture attachment, Writabl final MutableArray>> validPacketSubscribers; final MutableArray>> invalidPacketSubscribers; + final LockableArray> activeSinks; + final LockableOperations>> activeSinksOperations; final int maxPacketsByRead; @@ -84,6 +91,8 @@ public AbstractConnection( this.closed = new AtomicBoolean(false); this.validPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class); this.invalidPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class); + this.activeSinks = ArrayFactory.stampedLockBasedArray(FluxSink.class); + this.activeSinksOperations = activeSinks.operations(); this.remoteAddress = String.valueOf(NetworkUtils.getRemoteAddress(channel)); } @@ -134,10 +143,12 @@ protected void registerFluxOnReceivedEvents( validPacketSubscribers.add(validListener); invalidPacketSubscribers.add(invalidListener); + activeSinksOperations.inWriteLock(sink, Collection::add); sink.onDispose(() -> { validPacketSubscribers.remove(validListener); validPacketSubscribers.remove(invalidListener); + activeSinksOperations.inWriteLock(sink, Collection::remove); }); network.inNetworkThread(() -> packetReader().startRead()); @@ -146,14 +157,22 @@ protected void registerFluxOnReceivedEvents( protected void registerFluxOnReceivedValidPackets(FluxSink> sink) { BiConsumer> listener = (connection, packet) -> sink.next(packet); validPacketSubscribers.add(listener); - sink.onDispose(() -> validPacketSubscribers.remove(listener)); + activeSinksOperations.inWriteLock(sink, Collection::add); + sink.onDispose(() -> { + validPacketSubscribers.remove(listener); + activeSinksOperations.inWriteLock(sink, Collection::remove); + }); network.inNetworkThread(() -> packetReader().startRead()); } protected void registerFluxOnReceivedInvalidPackets(FluxSink> sink) { BiConsumer> listener = (connection, packet) -> sink.next(packet); invalidPacketSubscribers.add(listener); - sink.onDispose(() -> invalidPacketSubscribers.remove(listener)); + activeSinksOperations.inWriteLock(sink, Collection::add); + sink.onDispose(() -> { + invalidPacketSubscribers.remove(listener); + activeSinksOperations.inWriteLock(sink, Collection::remove); + }); network.inNetworkThread(() -> packetReader().startRead()); } @@ -184,6 +203,28 @@ protected void doClose() { clearWaitPackets(); packetReader().close(); packetWriter().close(); + notifyActiveSinks(); + } + + protected void notifyActiveSinks() { + Boolean noActiveSinks = activeSinksOperations.getInReadLock(Array::isEmpty); + if (noActiveSinks) { + return; + } + notifySinksWithError(new ConnectionClosedException(remoteAddress)); + activeSinksOperations.inWriteLock(Collection::clear); + } + + protected void notifySinksWithError(Throwable error) { + Array> localActiveSinks = activeSinksOperations.getInReadLock(Array::copyOf); + localActiveSinks.iterations().forEach( + error, (sink, exc) -> { + try { + sink.error(exc); + } catch (RuntimeException e) { + log.error(e.getMessage(), "Failed to notify sink of connection closure: "::formatted); + } + }); } /** diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java index de7e5b8a..ee683870 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java @@ -461,10 +461,14 @@ protected void handleFailedReceiving(Throwable exception, ByteBuffer readingBuff retryReadLater(); } } - case AsynchronousCloseException ex -> - log.info(remoteAddress(), "[%s] Connection was closed"::formatted); - case ClosedChannelException ex -> - log.info(remoteAddress(), "[%s] Connection was closed"::formatted); + case AsynchronousCloseException ex -> { + log.info(remoteAddress(), "[%s] Connection was closed"::formatted); + connection.close(); + } + case ClosedChannelException ex -> { + log.info(remoteAddress(), "[%s] Connection was closed"::formatted); + connection.close(); + } default -> { log.error(exception); connection.close(); diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java index 6ab75309..51701e2f 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java @@ -76,6 +76,7 @@ protected AbstractSslNetworkPacketReader( protected void handleReceivedData(int receivedBytes, ByteBuffer readingBuffer) { if (receivedBytes == -1) { doHandshake(sslNetworkBuffer(), -1); + connection.close(); return; } super.handleReceivedData(receivedBytes, readingBuffer); diff --git a/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java b/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java new file mode 100644 index 00000000..4b108fc0 --- /dev/null +++ b/rlib-network/src/test/java/javasabr/rlib/network/ConnectionCloseTest.java @@ -0,0 +1,98 @@ +package javasabr.rlib.network; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import javasabr.rlib.network.exception.ConnectionClosedException; +import javasabr.rlib.network.impl.AbstractConnection; +import javasabr.rlib.network.impl.DefaultConnection; +import javasabr.rlib.network.packet.impl.DefaultReadableNetworkPacket; +import javasabr.rlib.network.packet.impl.StringWritableNetworkPacket; +import javasabr.rlib.network.packet.registry.ReadableNetworkPacketRegistry; +import javasabr.rlib.network.util.NetworkUtils; +import javax.net.ssl.SSLContext; +import lombok.SneakyThrows; +import org.junit.jupiter.api.Test; + +public class ConnectionCloseTest extends BaseNetworkTest { + + @Test + void shouldPropagateConnectionCloseToClient() throws InterruptedException { + // given + var packetRegistry = ReadableNetworkPacketRegistry.of( + DefaultReadableNetworkPacket.class, + DefaultConnection.class, + DefaultNetworkTest.ServerPackets.RequestEchoMessage.class, + DefaultNetworkTest.ServerPackets.RequestServerTime.class); + var serverNetwork = NetworkFactory.defaultServerNetwork(packetRegistry); + InetSocketAddress serverAddress = serverNetwork.start(); + serverNetwork.onAccept(AbstractConnection::close); + var clientNetwork = NetworkFactory.defaultClientNetwork(packetRegistry); + CountDownLatch closeLatch = new CountDownLatch(1); + + // when + clientNetwork + .connectReactive(serverAddress) + .flatMapMany(AbstractConnection::receivedEvents) + .doOnError(e -> { + if (e instanceof ConnectionClosedException) { + closeLatch.countDown(); + } + }) + .subscribe(); + + // then + assertThat(closeLatch.await(5000, TimeUnit.MILLISECONDS)) + .as("Client should be notified that connection is closed") + .isTrue(); + clientNetwork.shutdown(); + serverNetwork.shutdown(); + } + + @Test + @SneakyThrows + void shouldCloseServerConnectionWhenClientClosesTcpChannelAbruptly() { + // Given: established SSL connection with completed handshake + InputStream keystoreFile = ConnectionCloseTest.class.getResourceAsStream("/ssl/rlib_test_cert.p12"); + SSLContext serverSslContext = NetworkUtils.createSslContext(keystoreFile, "test"); + SSLContext clientSslContext = NetworkUtils.createAllTrustedClientSslContext(); + + try (var testNetwork = buildStringSSLNetwork(serverSslContext, clientSslContext)) { + var serverConnection = testNetwork.serverToClient; + var clientConnection = testNetwork.clientToServer; + + // Register handler to start reading on server side + CountDownLatch dataReceivedLatch = new CountDownLatch(1); + serverConnection.onReceiveValidPacket((conn, packet) -> dataReceivedLatch.countDown()); + + // Send data to complete SSL handshake and deliver a packet + clientConnection.sendInBackground(new StringWritableNetworkPacket<>("handshake")); + + // Wait for the handshake to complete and data to be received + assertThat(dataReceivedLatch.await(5, TimeUnit.SECONDS)) + .as("SSL handshake should complete and data should be received by server") + .isTrue(); + + // When: close client's raw TCP channel without SSL close_notify + clientConnection.channel().close(); + + assertThat(awaitMillis(5000, serverConnection::closed)) + .as("Server connection should be closed after receiving EOF from abruptly closed client channel") + .isTrue(); + } + } + + private static boolean awaitMillis(long millis, Supplier a) throws InterruptedException { + for (int i = 0; i < millis / 100; i++) { + if (a.get()) { + return true; + } + Thread.sleep(100); + } + return false; + } +}