diff --git a/.evergreen/.evg.yml b/.evergreen/.evg.yml index 525861928f3..3524dca8002 100644 --- a/.evergreen/.evg.yml +++ b/.evergreen/.evg.yml @@ -2430,18 +2430,27 @@ buildvariants: - name: "socket-test-task" - matrix_name: "tests-netty" - matrix_spec: { auth: "noauth", ssl: "*", jdk: "jdk8", version: [ "7.0" ], topology: "replicaset", os: "linux", + matrix_spec: { auth: "noauth", ssl: "*", jdk: "jdk8", version: [ "8.0" ], topology: "replicaset", os: "linux", async-transport: "netty" } display_name: "Netty: ${version} ${topology} ${ssl} ${auth} ${jdk} ${os} " tags: [ "tests-netty-variant" ] + tasks: + - name: "test-core-task" + - name: "test-reactive-task" + + - matrix_name: "tests-netty-compression" + matrix_spec: { compressor: "snappy", auth: "noauth", ssl: "nossl", jdk: "jdk21", version: [ "8.0" ], topology: "replicaset", + os: "linux", async-transport: "netty" } + display_name: "Netty with Compression '${compressor}': ${version} ${topology} ${ssl} ${auth} ${jdk} ${os} " + tags: [ "tests-netty-variant" ] tasks: - name: "test-reactive-task" - name: "test-core-task" - matrix_name: "tests-netty-ssl-provider" - matrix_spec: { auth: "auth", ssl: "ssl", jdk: "jdk8", version: [ "7.0" ], topology: "replicaset", os: "linux", + matrix_spec: { auth: "auth", ssl: "ssl", jdk: "jdk8", version: [ "8.0" ], topology: "replicaset", os: "linux", async-transport: "netty", netty-ssl-provider: "*" } - display_name: "Netty SSL provider: ${version} ${topology} ${ssl} SslProvider.${netty-ssl-provider} ${auth} ${jdk} ${os} " + display_name: "Netty with SSL provider '${netty-ssl-provider}': ${version} ${topology} ${ssl} ${auth} ${jdk} ${os} " tags: [ "tests-netty-variant" ] tasks: - name: "test-reactive-task" diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 10bd5bc107d..1ffe7a229a5 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -67,14 +67,17 @@ provision_ssl () { } provision_multi_mongos_uri_for_ssl () { - # Arguments for auth + SSL - if [ "$AUTH" != "noauth" ] || [ "$TOPOLOGY" == "replica_set" ]; then - export MONGODB_URI="${MONGODB_URI}&ssl=true&sslInvalidHostNameAllowed=true" - export MULTI_MONGOS_URI="${MULTI_MONGOS_URI}&ssl=true&sslInvalidHostNameAllowed=true" - else - export MONGODB_URI="${MONGODB_URI}/?ssl=true&sslInvalidHostNameAllowed=true" - export MULTI_MONGOS_URI="${MULTI_MONGOS_URI}/?ssl=true&sslInvalidHostNameAllowed=true" - fi + if [[ "$MONGODB_URI" == *"?"* ]]; then + export MONGODB_URI="${MONGODB_URI}&ssl=true&sslInvalidHostNameAllowed=true" + else + export MONGODB_URI="${MONGODB_URI}/?ssl=true&sslInvalidHostNameAllowed=true" + fi + + if [[ "$MULTI_MONGOS_URI" == *"?"* ]]; then + export MULTI_MONGOS_URI="${MULTI_MONGOS_URI}&ssl=true&sslInvalidHostNameAllowed=true" + else + export MULTI_MONGOS_URI="${MULTI_MONGOS_URI}/?ssl=true&sslInvalidHostNameAllowed=true" + fi } ############################################ @@ -136,6 +139,8 @@ echo "Running tests with Java ${JAVA_VERSION}" --stacktrace --info --continue ${TESTS} | tee -a logs.txt if grep -q 'LEAK:' logs.txt ; then - echo "Netty Leak detected, please inspect build log" + echo "===============================================" + echo " Netty Leak detected, please inspect build log " + echo "===============================================" exit 1 fi diff --git a/bson/src/main/org/bson/BsonDocument.java b/bson/src/main/org/bson/BsonDocument.java index 87625de8dbd..f52a1f25f7d 100644 --- a/bson/src/main/org/bson/BsonDocument.java +++ b/bson/src/main/org/bson/BsonDocument.java @@ -921,10 +921,12 @@ private 
static class SerializationProxy implements Serializable { new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); this.bytes = new byte[buffer.size()]; int curPos = 0; - for (ByteBuf cur : buffer.getByteBuffers()) { + List byteBuffers = buffer.getByteBuffers(); + for (ByteBuf cur : byteBuffers) { System.arraycopy(cur.array(), cur.position(), bytes, curPos, cur.limit()); curPos += cur.position(); } + byteBuffers.forEach(ByteBuf::release); } private Object readResolve() { diff --git a/config/spotbugs/exclude.xml b/config/spotbugs/exclude.xml index 20684680865..d1647b1f149 100644 --- a/config/spotbugs/exclude.xml +++ b/config/spotbugs/exclude.xml @@ -290,4 +290,10 @@ + + + + + + diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonArray.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonArray.java index e02cee12629..5442e81de68 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonArray.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonArray.java @@ -23,6 +23,7 @@ import org.bson.ByteBuf; import org.bson.io.ByteBufferBsonInput; +import java.io.Closeable; import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; @@ -33,20 +34,31 @@ import static com.mongodb.internal.connection.ByteBufBsonHelper.readBsonValue; -final class ByteBufBsonArray extends BsonArray { +final class ByteBufBsonArray extends BsonArray implements Closeable { private final ByteBuf byteBuf; + /** + * List of resources that need to be closed when this array is closed. + * Tracks the main ByteBuf and iterator duplicates. Iterator buffers are automatically + * removed and released when iteration completes normally to prevent memory accumulation. + */ + private final List trackedResources = new ArrayList<>(); + private boolean closed; + ByteBufBsonArray(final ByteBuf byteBuf) { this.byteBuf = byteBuf; + trackedResources.add(byteBuf::release); } @Override public Iterator iterator() { + ensureOpen(); return new ByteBufBsonArrayIterator(); } @Override public List getValues() { + ensureOpen(); List values = new ArrayList<>(); for (BsonValue cur: this) { //noinspection UseBulkOperation @@ -59,6 +71,7 @@ public List getValues() { @Override public int size() { + ensureOpen(); int size = 0; for (BsonValue ignored : this) { size++; @@ -68,11 +81,13 @@ public int size() { @Override public boolean isEmpty() { + ensureOpen(); return !iterator().hasNext(); } @Override public boolean equals(final Object o) { + ensureOpen(); if (o == this) { return true; } @@ -91,6 +106,7 @@ public boolean equals(final Object o) { @Override public int hashCode() { + ensureOpen(); int hashCode = 1; for (BsonValue cur : this) { hashCode = 31 * hashCode + (cur == null ? 0 : cur.hashCode()); @@ -100,6 +116,7 @@ public int hashCode() { @Override public boolean contains(final Object o) { + ensureOpen(); for (BsonValue cur : this) { if (Objects.equals(o, cur)) { return true; @@ -111,6 +128,7 @@ public boolean contains(final Object o) { @Override public Object[] toArray() { + ensureOpen(); Object[] retVal = new Object[size()]; Iterator it = iterator(); for (int i = 0; i < retVal.length; i++) { @@ -122,6 +140,7 @@ public Object[] toArray() { @Override @SuppressWarnings("unchecked") public T[] toArray(final T[] a) { + ensureOpen(); int size = size(); T[] retVal = a.length >= size ? 
a : (T[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size); Iterator it = iterator(); @@ -133,6 +152,7 @@ public T[] toArray(final T[] a) { @Override public boolean containsAll(final Collection c) { + ensureOpen(); for (Object e : c) { if (!contains(e)) { return false; @@ -143,6 +163,7 @@ public boolean containsAll(final Collection c) { @Override public BsonValue get(final int index) { + ensureOpen(); if (index < 0) { throw new IndexOutOfBoundsException("Index out of range: " + index); } @@ -159,6 +180,7 @@ public BsonValue get(final int index) { @Override public int indexOf(final Object o) { + ensureOpen(); int i = 0; for (BsonValue cur : this) { if (Objects.equals(o, cur)) { @@ -172,6 +194,7 @@ public int indexOf(final Object o) { @Override public int lastIndexOf(final Object o) { + ensureOpen(); ListIterator listIterator = listIterator(size()); while (listIterator.hasPrevious()) { if (Objects.equals(o, listIterator.previous())) { @@ -183,17 +206,20 @@ public int lastIndexOf(final Object o) { @Override public ListIterator listIterator() { + ensureOpen(); return listIterator(0); } @Override public ListIterator listIterator(final int index) { + ensureOpen(); // Not the most efficient way to do this, but unlikely anyone will notice in practice return new ArrayList<>(this).listIterator(index); } @Override public List subList(final int fromIndex, final int toIndex) { + ensureOpen(); if (fromIndex < 0) { throw new IndexOutOfBoundsException("fromIndex = " + fromIndex); } @@ -234,6 +260,7 @@ public boolean addAll(final Collection c) { @Override public boolean addAll(final int index, final Collection c) { + ensureOpen(); throw new UnsupportedOperationException(READ_ONLY_MESSAGE); } @@ -267,11 +294,43 @@ public BsonValue remove(final int index) { throw new UnsupportedOperationException(READ_ONLY_MESSAGE); } + @Override + public void close(){ + if (!closed) { + for (Closeable closeable : trackedResources) { + try { + closeable.close(); + } catch (Exception e) { + // Log and continue closing other resources + } + } + trackedResources.clear(); + closed = true; + } + } + + private void ensureOpen() { + if (closed) { + throw new IllegalStateException("The BsonArray resources have been released."); + } + } + private class ByteBufBsonArrayIterator implements Iterator { - private final ByteBuf duplicatedByteBuf = byteBuf.duplicate(); - private final BsonBinaryReader bsonReader; + private ByteBuf duplicatedByteBuf; + private BsonBinaryReader bsonReader; + private Closeable resourceHandle; + private boolean finished; { + ensureOpen(); + duplicatedByteBuf = byteBuf.duplicate(); + resourceHandle = () -> { + if (duplicatedByteBuf != null) { + duplicatedByteBuf.release(); + duplicatedByteBuf = null; + } + }; + trackedResources.add(resourceHandle); bsonReader = new BsonBinaryReader(new ByteBufferBsonInput(duplicatedByteBuf)); // While one might expect that this would be a call to BsonReader#readStartArray that doesn't work because BsonBinaryReader // expects to be positioned at the start at the beginning of a document, not an array. 
Fortunately, a BSON array has exactly @@ -283,7 +342,11 @@ private class ByteBufBsonArrayIterator implements Iterator { @Override public boolean hasNext() { - return bsonReader.getCurrentBsonType() != BsonType.END_OF_DOCUMENT; + boolean hasNext = bsonReader.getCurrentBsonType() != BsonType.END_OF_DOCUMENT; + if (!hasNext) { + cleanup(); + } + return hasNext; } @Override @@ -292,9 +355,22 @@ public BsonValue next() { throw new NoSuchElementException(); } bsonReader.skipName(); - BsonValue value = readBsonValue(duplicatedByteBuf, bsonReader); + BsonValue value = readBsonValue(duplicatedByteBuf, bsonReader, trackedResources); bsonReader.readBsonType(); return value; } + + private void cleanup() { + if (!finished) { + finished = true; + // Remove from tracked resources since we're cleaning up immediately + trackedResources.remove(resourceHandle); + try { + resourceHandle.close(); + } catch (Exception e) { + // Ignore + } + } + } } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java index 70ed10a75a8..57b5967cd23 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonDocument.java @@ -16,139 +16,529 @@ package com.mongodb.internal.connection; +import com.mongodb.MongoInternalException; +import com.mongodb.internal.VisibleForTesting; import com.mongodb.lang.Nullable; +import org.bson.BsonArray; import org.bson.BsonBinaryReader; import org.bson.BsonDocument; +import org.bson.BsonReader; import org.bson.BsonType; import org.bson.BsonValue; import org.bson.ByteBuf; -import org.bson.RawBsonDocument; import org.bson.codecs.BsonDocumentCodec; import org.bson.codecs.DecoderContext; import org.bson.io.ByteBufferBsonInput; import org.bson.json.JsonMode; -import org.bson.json.JsonWriter; import org.bson.json.JsonWriterSettings; +import java.io.ByteArrayOutputStream; +import java.io.Closeable; import java.io.InvalidObjectException; import java.io.ObjectInputStream; -import java.io.StringWriter; +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; import java.util.AbstractCollection; import java.util.AbstractMap; import java.util.AbstractSet; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; +import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.VisibleForTesting.AccessModifier.PACKAGE; +import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE; import static com.mongodb.internal.connection.ByteBufBsonHelper.readBsonValue; +import static java.util.Collections.emptyMap; -final class ByteBufBsonDocument extends BsonDocument { +/** + * A memory-efficient, read-only {@link BsonDocument} implementation backed by a {@link ByteBuf}. + * + *

+ * <p><b>Overview</b></p>
+ * <p>This class provides lazy access to BSON document fields without fully deserializing the document
+ * into memory. It reads field values directly from the underlying byte buffer on demand, which is
+ * particularly useful for large documents where only a few fields need to be accessed.</p>
+ *
+ * <p><b>Data Sources</b></p>
+ * <p>A {@code ByteBufBsonDocument} can contain data from two sources:</p>
+ * <ul>
+ * <li><b>Body fields:</b> Standard BSON document fields stored in {@link #bodyByteBuf}. These are
+ * read lazily using a {@link BsonBinaryReader}.</li>
+ * <li><b>Sequence fields:</b> MongoDB OP_MSG Type 1 payload sequences stored in {@link #sequenceFields}.
+ * These are used when parsing command messages that contain document sequences (e.g., bulk inserts).
+ * Each sequence field appears as an array of documents when accessed.</li>
+ * </ul>
+ *
+ * <p><b>OP_MSG Command Message Support</b></p>
+ * <p>The {@link #createCommandMessage(CompositeByteBuf)} factory method parses the MongoDB OP_MSG format,
+ * which consists of:</p>
+ * <ol>
+ * <li>A body section (Type 0): the main command document</li>
+ * <li>Zero or more document sequence sections (Type 1): arrays of documents identified by field name</li>
+ * </ol>
+ * <p>For example, an insert command might have a body containing {@code {insert: "collection", $db: "test"}}
+ * and a sequence section with field name "documents" containing the documents to insert.</p>
+ *
+ * <p><b>Resource Management</b></p>
+ * <p>This class implements {@link Closeable} and manages several types of resources:</p>
+ * <ul>
+ * <li><b>ByteBuf instances:</b> The body buffer and any duplicated buffers created during iteration
+ * or value access are tracked in {@link #trackedResources} and released on {@link #close()}.</li>
+ * <li><b>Nested ByteBufBsonDocument/ByteBufBsonArray:</b> When accessing nested documents or arrays,
+ * new {@code ByteBufBsonDocument} or {@link ByteBufBsonArray} instances are created. These are
+ * registered as closeables and closed recursively when the parent is closed.</li>
+ * <li><b>Sequence field documents:</b> Documents within sequence fields are also {@code ByteBufBsonDocument}
+ * instances that are tracked and closed with the parent.</li>
+ * </ul>
+ * <p><b>Important:</b> Always close this document when done to prevent memory leaks. After closing,
+ * any operation will throw {@link IllegalStateException}.</p>
+ *
+ * <p><b>Caching Strategy</b></p>
+ * <p>The class uses lazy caching to optimize repeated access:</p>
+ * <ul>
+ * <li>{@link #cachedDocument}: Once {@link #toBsonDocument()} is called, the fully hydrated document
+ * is cached and all subsequent operations use this cache. At this point, the underlying buffers
+ * are released since they are no longer needed.</li>
+ * <li>{@link #cachedFirstKey}: The first key is cached after the first call to {@link #getFirstKey()}.</li>
+ * <li>Sequence field arrays are cached within {@link SequenceField} after first access.</li>
+ * </ul>
+ *
+ * <p><b>Immutability</b></p>
+ * <p>This class is read-only. All mutation methods ({@link #put}, {@link #remove}, {@link #clear}, etc.)
+ * throw {@link UnsupportedOperationException}.</p>
+ *
+ * <p><b>Thread Safety</b></p>
+ * <p>This class is not thread-safe. Concurrent access from multiple threads requires external synchronization.</p>
+ *
+ * <p><b>Serialization</b></p>
+ * <p>Java serialization is supported via {@link #writeReplace()}, which converts this document to a
+ * regular {@link BsonDocument} before serialization.</p>
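+ * <p><b>Example</b> (an illustrative sketch only; {@code commandByteBuf} is a hypothetical buffer
+ * positioned at an OP_MSG body, and in the driver these instances are created and closed internally):</p>
+ * <pre>{@code
+ * try (ByteBufBsonDocument doc = ByteBufBsonDocument.createCommandMessage(commandByteBuf)) {
+ *     String firstKey = doc.getFirstKey(); // read lazily from the underlying buffer
+ * } // close() releases all tracked buffers; further access throws IllegalStateException
+ * }</pre>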
+ * + * @see ByteBufBsonArray + * @see ByteBufBsonHelper + */ +public final class ByteBufBsonDocument extends BsonDocument implements Closeable { private static final long serialVersionUID = 2L; - private final transient ByteBuf byteBuf; + /** + * The underlying byte buffer containing the BSON document body. + * This is the main document data, excluding any OP_MSG sequence sections. + * Set to null after {@link #releaseResources()} is called. + */ + private transient ByteBuf bodyByteBuf; + + /** + * Map of sequence field names to their corresponding {@link SequenceField} instances. + * These represent OP_MSG Type 1 payload sections. Each sequence field appears as an + * array when accessed via {@link #get(Object)}. + * Empty for simple documents not created from OP_MSG. + */ + private transient Map sequenceFields; + + /** + * List of resources that need to be closed/released when this document is closed. + * + *

+     * <p>Memory Management Strategy:</p>
+     * <ul>
+     * <li><b>Always tracked:</b> The main bodyByteBuf and any nested ByteBufBsonDocument/ByteBufBsonArray
+     * instances returned to callers are permanently tracked until this document is closed or
+     * {@link #toBsonDocument()} caches and releases them.</li>
+     * <li><b>Temporarily tracked:</b> Iterator duplicate buffers are tracked during iteration
+     * but automatically removed and released when iteration completes. This prevents memory accumulation
+     * from completed iterations while ensuring cleanup if the parent document is closed mid-iteration.</li>
+     * <li><b>Not tracked:</b> Short-lived duplicate buffers used in query methods
+     * (e.g., {@link #findKeyInBody}, {@link #containsKey}) are released immediately in finally blocks
+     * and never added to this list. Temporary nested documents created during value comparison
+     * use separate tracking lists.</li>
+     * </ul>
+ */ + private final transient List trackedResources; /** - * Create a list of ByteBufBsonDocument from a buffer positioned at the start of the first document of an OP_MSG Section - * of type Document Sequence (Kind 1). - *

- * The provided buffer will be positioned at the end of the section upon normal completion of the method + * Cached fully-hydrated BsonDocument. Once populated via {@link #toBsonDocument()}, + * all subsequent read operations use this cache instead of reading from the byte buffer. */ - static List createList(final ByteBuf outputByteBuf) { - List documents = new ArrayList<>(); - while (outputByteBuf.hasRemaining()) { - ByteBufBsonDocument curDocument = createOne(outputByteBuf); - documents.add(curDocument); + private transient BsonDocument cachedDocument; + + /** + * Cached first key of the document. Populated on first call to {@link #getFirstKey()}. + */ + private transient String cachedFirstKey; + + /** + * Flag indicating whether this document has been closed. + * Once closed, all operations throw {@link IllegalStateException}. + */ + private transient boolean closed; + + + /** + * Creates a {@code ByteBufBsonDocument} from an OP_MSG command message. + * + *

+     * <p>This factory method parses the MongoDB OP_MSG wire protocol format, which consists of:</p>
+     * <ol>
+     * <li>Body section (Type 0): a single BSON document containing the command</li>
+     * <li>Document sequence sections (Type 1): zero or more sections, each containing
+     * a field identifier and a sequence of BSON documents</li>
+     * </ol>
+     * <p>The sequence sections are stored in {@link #sequenceFields} and appear as array fields
+     * when the document is accessed. For example, an insert command's "documents" sequence
+     * will appear as an array when calling {@code get("documents")}.</p>
+     *
+     * <p><b>Wire Format Parsed</b></p>
+     * <pre>
+     * [body document bytes]
+     * [section type: 1 byte] [section size: 4 bytes] [identifier: cstring] [document bytes...]
+     * ... (more sections)
+     * </pre>
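+     * <p>An illustrative sketch (hypothetical wire contents; {@code msgBuf} is assumed to hold an
+     * insert command with a "documents" sequence):</p>
+     * <pre>{@code
+     * ByteBufBsonDocument cmd = ByteBufBsonDocument.createCommandMessage(msgBuf);
+     * cmd.getFirstKey();    // "insert", read from the Type 0 body section
+     * cmd.get("documents"); // a BsonArray of lazily-parsed documents from the Type 1 section
+     * }</pre>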
+ * + * @param commandMessageByteBuf The composite buffer positioned at the start of the body document. + * Position will be advanced past all parsed sections. + * @return A new {@code ByteBufBsonDocument} representing the command with any sequence fields. + */ + @VisibleForTesting(otherwise = PRIVATE) + public static ByteBufBsonDocument createCommandMessage(final CompositeByteBuf commandMessageByteBuf) { + // Parse body document: read size, create a view of just the body bytes + int bodyStart = commandMessageByteBuf.position(); + int bodySizeInBytes = commandMessageByteBuf.getInt(); + int bodyEnd = bodyStart + bodySizeInBytes; + ByteBuf bodyByteBuf = commandMessageByteBuf.duplicate().position(bodyStart).limit(bodyEnd); + + List trackedResources = new ArrayList<>(); + commandMessageByteBuf.position(bodyEnd); + + // Parse any Type 1 (document sequence) sections that follow the body + Map sequences = new LinkedHashMap<>(); + while (commandMessageByteBuf.hasRemaining()) { + // Skip section type byte (we only support Type 1 here) + commandMessageByteBuf.position(commandMessageByteBuf.position() + 1); + + // Read section size and calculate bounds + int sequenceStart = commandMessageByteBuf.position(); + int sequenceSizeInBytes = commandMessageByteBuf.getInt(); + int sectionEnd = sequenceStart + sequenceSizeInBytes; + + // Read the field identifier (null-terminated string) + String fieldName = readCString(commandMessageByteBuf); + assertFalse(fieldName.contains(".")); + + // Create a view of just the document sequence bytes (after the identifier) + ByteBuf sequenceByteBuf = commandMessageByteBuf.duplicate(); + sequenceByteBuf.position(commandMessageByteBuf.position()).limit(sectionEnd); + sequences.put(fieldName, new SequenceField(sequenceByteBuf, trackedResources)); + commandMessageByteBuf.position(sectionEnd); } - return documents; + return new ByteBufBsonDocument(bodyByteBuf, trackedResources, sequences); } /** - * Create a ByteBufBsonDocument from a buffer positioned at the start of a BSON document. - * The provided buffer will be positioned at the end of the document upon normal completion of the method + * Creates a simple {@code ByteBufBsonDocument} from a byte buffer containing a single BSON document. + * + *

+     * <p>Use this constructor for standard BSON documents. For OP_MSG command messages with
+     * document sequences, use {@link #createCommandMessage(CompositeByteBuf)} instead.</p>

+ * + * @param byteBuf The buffer containing the BSON document. The buffer should be positioned + * at the start of the document and contain the complete document bytes. + */ + @VisibleForTesting(otherwise = PACKAGE) + public ByteBufBsonDocument(final ByteBuf byteBuf) { + this(byteBuf, new ArrayList<>(), new HashMap<>()); + } + + /** + * Private constructor used by factory methods. + * + * @param bodyByteBuf The buffer containing the body document bytes + * @param trackedResources Mutable list for tracking resources to close + * @param sequenceFields Map of sequence field names to their data (empty for simple documents) */ - static ByteBufBsonDocument createOne(final ByteBuf outputByteBuf) { - int documentStart = outputByteBuf.position(); - int documentSizeInBytes = outputByteBuf.getInt(); - int documentEnd = documentStart + documentSizeInBytes; - ByteBuf slice = outputByteBuf.duplicate().position(documentStart).limit(documentEnd); - outputByteBuf.position(documentEnd); - return new ByteBufBsonDocument(slice); + private ByteBufBsonDocument(final ByteBuf bodyByteBuf, final List trackedResources, + final Map sequenceFields) { + this.bodyByteBuf = bodyByteBuf; + this.trackedResources = trackedResources; + this.sequenceFields = sequenceFields; + trackedResources.add(bodyByteBuf::release); } + // ==================== Size and Empty Checks ==================== + @Override - public String toJson() { - return toJson(JsonWriterSettings.builder().outputMode(JsonMode.RELAXED).build()); + public int size() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.size(); + } + // Total size = body fields + sequence fields + return countBodyFields() + sequenceFields.size(); } @Override - public String toJson(final JsonWriterSettings settings) { - StringWriter stringWriter = new StringWriter(); - JsonWriter jsonWriter = new JsonWriter(stringWriter, settings); - ByteBuf duplicate = byteBuf.duplicate(); - try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(duplicate))) { - jsonWriter.pipe(reader); - return stringWriter.toString(); - } finally { - duplicate.release(); + public boolean isEmpty() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.isEmpty(); } + return !hasBodyFields() && sequenceFields.isEmpty(); } + // ==================== Key/Value Lookups ==================== + @Override - public BsonBinaryReader asBsonReader() { - return new BsonBinaryReader(new ByteBufferBsonInput(byteBuf.duplicate())); + public boolean containsKey(final Object key) { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.containsKey(key); + } + if (key == null) { + throw new IllegalArgumentException("key can not be null"); + } + // Check sequence fields first (fast HashMap lookup), then scan body + if (sequenceFields.containsKey(key)) { + return true; + } + return findKeyInBody((String) key); } - @SuppressWarnings("MethodDoesntCallSuperMethod") @Override - public BsonDocument clone() { - byte[] clonedBytes = new byte[byteBuf.remaining()]; - byteBuf.get(byteBuf.position(), clonedBytes); - return new RawBsonDocument(clonedBytes); + public boolean containsValue(final Object value) { + ensureOpen(); + if (!(value instanceof BsonValue)) { + return false; + } + + if (cachedDocument != null) { + return cachedDocument.containsValue(value); + } + + // Search body fields first, then sequence fields + if (findValueInBody((BsonValue) value)) { + return true; + } + for (SequenceField field : sequenceFields.values()) { + if (field.containsValue(value)) { 
+ return true; + } + } + return false; } + /** + * {@inheritDoc} + * + *

+     * <p>For sequence fields (OP_MSG document sequences), returns a {@link BsonArray} containing
+     * {@code ByteBufBsonDocument} instances for each document in the sequence.</p>

+ */ @Nullable - T findInDocument(final Finder finder) { - ByteBuf duplicateByteBuf = byteBuf.duplicate(); - try (BsonBinaryReader bsonReader = new BsonBinaryReader(new ByteBufferBsonInput(duplicateByteBuf))) { - bsonReader.readStartDocument(); - while (bsonReader.readBsonType() != BsonType.END_OF_DOCUMENT) { - T found = finder.find(duplicateByteBuf, bsonReader); - if (found != null) { - return found; - } + @Override + public BsonValue get(final Object key) { + ensureOpen(); + notNull("key", key); + + if (!(key instanceof String)) { + return null; + } + if (cachedDocument != null) { + return cachedDocument.get(key); + } + + // Check sequence fields first, then body + if (sequenceFields.containsKey(key)) { + return sequenceFields.get(key).asArray(); + } + return getValueFromBody((String) key); + } + + @Override + public String getFirstKey() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.getFirstKey(); + } + if (cachedFirstKey != null) { + return cachedFirstKey; + } + cachedFirstKey = getFirstKeyFromBody(); + return assertNotNull(cachedFirstKey); + } + + // ==================== Collection Views ==================== + // These return lazy views that iterate over both body and sequence fields + + @Override + public Set> entrySet() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.entrySet(); + } + return new AbstractSet>() { + @Override + public Iterator> iterator() { + // Combine body entries with sequence entries + return new CombinedIterator<>(createBodyIterator(IteratorMode.ENTRIES), createSequenceEntryIterator()); } - bsonReader.readEndDocument(); - } finally { - duplicateByteBuf.release(); + + @Override + public int size() { + return ByteBufBsonDocument.this.size(); + } + }; + } + + @Override + public Collection values() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.values(); } + return new AbstractCollection() { + @Override + public Iterator iterator() { + return new CombinedIterator<>(createBodyIterator(IteratorMode.VALUES), createSequenceValueIterator()); + } - return finder.notFound(); + @Override + public int size() { + return ByteBufBsonDocument.this.size(); + } + }; } - BsonDocument toBaseBsonDocument() { - ByteBuf duplicateByteBuf = byteBuf.duplicate(); - try (BsonBinaryReader bsonReader = new BsonBinaryReader(new ByteBufferBsonInput(duplicateByteBuf))) { - return new BsonDocumentCodec().decode(bsonReader, DecoderContext.builder().build()); - } finally { - duplicateByteBuf.release(); + @Override + public Set keySet() { + ensureOpen(); + if (cachedDocument != null) { + return cachedDocument.keySet(); + } + return new AbstractSet() { + @Override + public Iterator iterator() { + return new CombinedIterator<>(createBodyIterator(IteratorMode.KEYS), sequenceFields.keySet().iterator()); + } + + @Override + public int size() { + return ByteBufBsonDocument.this.size(); + } + }; + } + + // ==================== Conversion Methods ==================== + + @Override + public BsonReader asBsonReader() { + ensureOpen(); + // Must hydrate first since we need to include sequence fields + return toBsonDocument().asBsonReader(); + } + + /** + * Converts this document to a regular {@link BsonDocument}, fully deserializing all data. + * + *

+     * <p>After this method is called:</p>
+     * <ul>
+     * <li>The result is cached for future calls</li>
+     * <li>All underlying byte buffers are released</li>
+     * <li>Sequence field documents are hydrated to regular {@code BsonDocument} instances</li>
+     * <li>All subsequent read operations use the cached document</li>
+     * </ul>
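+     * <p>An illustrative sketch of the caching contract (hypothetical caller):</p>
+     * <pre>{@code
+     * BsonDocument hydrated = doc.toBsonDocument(); // decodes body and sequences, releases buffers
+     * doc.getFirstKey();                            // now served from the cached document
+     * }</pre>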
+ * + * @return A fully materialized {@link BsonDocument} containing all fields + */ + @Override + public BsonDocument toBsonDocument() { + ensureOpen(); + if (cachedDocument == null) { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + // Decode body document + BsonDocument doc = new BsonDocumentCodec().decode(reader, DecoderContext.builder().build()); + // Add hydrated sequence fields + for (Map.Entry entry : sequenceFields.entrySet()) { + doc.put(entry.getKey(), entry.getValue().toHydratedArray()); + } + cachedDocument = doc; + // Release buffers since we no longer need them + releaseResources(); + } finally { + dup.release(); + } } + return cachedDocument; } - ByteBufBsonDocument(final ByteBuf byteBuf) { - this.byteBuf = byteBuf; + @Override + public String toJson() { + return toJson(JsonWriterSettings.builder().outputMode(JsonMode.RELAXED).build()); } @Override - public void clear() { - throw new UnsupportedOperationException("ByteBufBsonDocument instances are immutable"); + public String toJson(final JsonWriterSettings settings) { + ensureOpen(); + return toBsonDocument().toJson(settings); } + @Override + public String toString() { + ensureOpen(); + return toBsonDocument().toString(); + } + + @SuppressWarnings("MethodDoesntCallSuperMethod") + @Override + public BsonDocument clone() { + ensureOpen(); + return toBsonDocument().clone(); + } + + @SuppressWarnings("EqualsDoesntCheckParameterClass") + @Override + public boolean equals(final Object o) { + ensureOpen(); + return toBsonDocument().equals(o); + } + + @Override + public int hashCode() { + ensureOpen(); + return toBsonDocument().hashCode(); + } + + // ==================== Resource Management ==================== + + /** + * Releases all resources held by this document. + * + *

+     * <p>This includes:</p>
+     * <ul>
+     * <li>Releasing all tracked {@link ByteBuf} instances</li>
+     * <li>Closing all nested {@code ByteBufBsonDocument} and {@link ByteBufBsonArray} instances</li>
+     * <li>Clearing internal references</li>
+     * </ul>
+     * <p>After calling this method, any operation on this document will throw
+     * {@link IllegalStateException}. This method is idempotent.</p>
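+     * <p>A sketch of the post-close contract (illustrative):</p>
+     * <pre>{@code
+     * doc.close();
+     * doc.close();         // no-op: close is idempotent
+     * doc.get("anyField"); // throws IllegalStateException
+     * }</pre>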

+ */ + @Override + public void close() { + if (!closed) { + closed = true; + releaseResources(); + } + } + + // ==================== Mutation Methods (Unsupported) ==================== + @Override public BsonValue put(final String key, final BsonValue value) { throw new UnsupportedOperationException("ByteBufBsonDocument instances are immutable"); @@ -170,260 +560,476 @@ public BsonValue remove(final Object key) { } @Override - public boolean isEmpty() { - return assertNotNull(findInDocument(new Finder() { - @Override - public Boolean find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - return false; - } - - @Override - public Boolean notFound() { - return true; - } - })); + public void clear() { + throw new UnsupportedOperationException("ByteBufBsonDocument instances are immutable"); } - @Override - public int size() { - return assertNotNull(findInDocument(new Finder() { - private int size; + // ==================== Private Body Field Operations ==================== + // These methods read from bodyByteBuf using a temporary duplicate buffer - @Override - @Nullable - public Integer find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - size++; - bsonReader.readName(); - bsonReader.skipValue(); - return null; + /** + * Searches the body for a field with the given key. + * Uses a duplicated buffer to avoid modifying the original position. + */ + private boolean findKeyInBody(final String key) { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + if (reader.readName().equals(key)) { + return true; + } + reader.skipValue(); } + return false; + } finally { + dup.release(); + } + } - @Override - public Integer notFound() { - return size; + /** + * Searches the body for a field with the given value. + * Creates ByteBufBsonDocument/ByteBufBsonArray for nested structures during comparison or vanilla BsonValues. + * Uses temporary tracking list to avoid polluting the main trackedResources with short-lived objects. + */ + private boolean findValueInBody(final BsonValue targetValue) { + ByteBuf dup = bodyByteBuf.duplicate(); + List tempTrackedResources = new ArrayList<>(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + reader.skipName(); + if (readBsonValue(dup, reader, tempTrackedResources).equals(targetValue)) { + return true; + } } - })); + return false; + } finally { + // Release temporary resources created during comparison + for (Closeable resource : tempTrackedResources) { + try { + resource.close(); + } catch (Exception e) { + // Continue closing other resources + } + } + dup.release(); + } } - @Override - public Set> entrySet() { - return new ByteBufBsonDocumentEntrySet(); + /** + * Retrieves a value from the body by key. + * Returns null if the key is not found in the body. 
+ */ + @Nullable + private BsonValue getValueFromBody(final String key) { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + if (reader.readName().equals(key)) { + return readBsonValue(dup, reader, trackedResources); + } + reader.skipValue(); + } + return null; + } finally { + dup.release(); + } } - @Override - public Collection values() { - return new ByteBufBsonDocumentValuesCollection(); + /** + * Gets the first key from the body, or from sequence fields if body is empty. + * Throws NoSuchElementException if the document is completely empty. + */ + private String getFirstKeyFromBody() { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + if (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + return reader.readName(); + } + // Body is empty, try sequence fields + if (!sequenceFields.isEmpty()) { + return sequenceFields.keySet().iterator().next(); + } + throw new NoSuchElementException(); + } finally { + dup.release(); + } } - @Override - public Set keySet() { - return new ByteBufBsonDocumentKeySet(); + /** + * Checks if the body contains at least one field. + */ + private boolean hasBodyFields() { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + return reader.readBsonType() != BsonType.END_OF_DOCUMENT; + } finally { + dup.release(); + } } - @Override - public boolean containsKey(final Object key) { - if (key == null) { - throw new IllegalArgumentException("key can not be null"); + /** + * Counts the number of fields in the body document. + */ + private int countBodyFields() { + ByteBuf dup = bodyByteBuf.duplicate(); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(dup))) { + reader.readStartDocument(); + int count = 0; + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + count++; + reader.skipName(); + reader.skipValue(); + } + return count; + } finally { + dup.release(); } + } + + // ==================== Iterator Support ==================== + + /** + * Mode for the body iterator, determining what type of elements it produces. + */ + private enum IteratorMode { ENTRIES, KEYS, VALUES } + + /** + * Creates an iterator over the body document fields. + * + *

+     * <p>The iterator creates a duplicated ByteBuf that is temporarily tracked for safety.
+     * When iteration completes normally, the buffer is released immediately and removed from tracking.
+     * This prevents accumulation of finished iterator buffers while ensuring cleanup if the parent
+     * document is closed before iteration completes.</p>
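+     * <p>An illustrative sketch of the intended use within this class (the element type is
+     * inferred from the assignment target):</p>
+     * <pre>{@code
+     * Iterator<String> keys = createBodyIterator(IteratorMode.KEYS);
+     * while (keys.hasNext()) {
+     *     keys.next();
+     * } // the final hasNext() call released the duplicated buffer and removed its tracking entry
+     * }</pre>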
+ * + * @param mode Determines whether to return entries, keys, or values + * @return An iterator of the appropriate type + */ + @SuppressWarnings("unchecked") + private Iterator createBodyIterator(final IteratorMode mode) { + return new Iterator() { + private ByteBuf duplicatedByteBuf; + private BsonBinaryReader reader; + private Closeable resourceHandle; + private boolean started; + private boolean finished; + + { + // Create duplicate buffer for iteration and track it temporarily + duplicatedByteBuf = bodyByteBuf.duplicate(); + resourceHandle = () -> { + if (duplicatedByteBuf != null) { + try { + if (reader != null) { + reader.close(); + } + } catch (Exception e) { + // Ignore + } + duplicatedByteBuf.release(); + duplicatedByteBuf = null; + reader = null; + } + }; + trackedResources.add(resourceHandle); + reader = new BsonBinaryReader(new ByteBufferBsonInput(duplicatedByteBuf)); + } - Boolean containsKey = findInDocument(new Finder() { @Override - public Boolean find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - if (bsonReader.readName().equals(key)) { - return true; + public boolean hasNext() { + if (finished) { + return false; } - bsonReader.skipValue(); - return null; + if (!started) { + reader.readStartDocument(); + reader.readBsonType(); + started = true; + } + boolean hasNext = reader.getCurrentBsonType() != BsonType.END_OF_DOCUMENT; + if (!hasNext) { + cleanup(); + } + return hasNext; } @Override - public Boolean notFound() { - return false; + public T next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + String key = reader.readName(); + BsonValue value = readBsonValue(duplicatedByteBuf, reader, trackedResources); + reader.readBsonType(); + + switch (mode) { + case ENTRIES: + return (T) new AbstractMap.SimpleImmutableEntry<>(key, value); + case KEYS: + return (T) key; + case VALUES: + return (T) value; + default: + throw new IllegalStateException("Unknown iterator mode: " + mode); + } } - }); - return containsKey != null ? containsKey : false; + + private void cleanup() { + if (!finished) { + finished = true; + // Remove from tracked resources since we're cleaning up immediately + trackedResources.remove(resourceHandle); + try { + resourceHandle.close(); + } catch (Exception e) { + // Ignore + } + } + } + }; } - @Override - public boolean containsValue(final Object value) { - Boolean containsValue = findInDocument(new Finder() { + /** + * Creates an iterator over sequence fields as map entries. + * Each entry contains the field name and its array value. + */ + private Iterator> createSequenceEntryIterator() { + Iterator> iter = sequenceFields.entrySet().iterator(); + return new Iterator>() { @Override - public Boolean find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - bsonReader.skipName(); - if (readBsonValue(byteBuf, bsonReader).equals(value)) { - return true; - } - return null; + public boolean hasNext() { + return iter.hasNext(); } @Override - public Boolean notFound() { - return false; + public Entry next() { + Map.Entry entry = iter.next(); + return new AbstractMap.SimpleImmutableEntry<>(entry.getKey(), entry.getValue().asArray()); } - }); - return containsValue != null ? containsValue : false; + }; } - @Nullable - @Override - public BsonValue get(final Object key) { - notNull("key", key); - return findInDocument(new Finder() { + /** + * Creates an iterator over sequence field values (arrays). 
+ */ + private Iterator createSequenceValueIterator() { + Iterator iter = sequenceFields.values().iterator(); + return new Iterator() { @Override - public BsonValue find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - if (bsonReader.readName().equals(key)) { - return readBsonValue(byteBuf, bsonReader); - } - bsonReader.skipValue(); - return null; + public boolean hasNext() { + return iter.hasNext(); } - @Nullable @Override - public BsonValue notFound() { - return null; + public BsonValue next() { + return iter.next().asArray(); } - }); + }; } + // ==================== Resource Management Helpers ==================== + /** - * Gets the first key in this document. + * Releases all tracked resources and clears internal state. * - * @return the first key in this document - * @throws java.util.NoSuchElementException if the document is empty + *

+     * <p>Called by {@link #close()} and after {@link #toBsonDocument()} caches the result.
+     * Resources include ByteBuf instances and nested ByteBufBsonDocument/ByteBufBsonArray.</p>

*/ - public String getFirstKey() { - return assertNotNull(findInDocument(new Finder() { - @Override - public String find(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { - return bsonReader.readName(); + private void releaseResources() { + for (Closeable resource : trackedResources) { + try { + resource.close(); + } catch (Exception e) { + // Log and continue closing other resources } + } - @Override - public String notFound() { - throw new NoSuchElementException(); - } - })); + assertTrue(bodyByteBuf == null || bodyByteBuf.getReferenceCount() == 0, "Failed to release all `bodyByteBuf` resources"); + assertTrue(sequenceFields.values().stream().allMatch(b -> b.sequenceByteBuf.getReferenceCount() == 0), + "Failed to release all `sequenceField` resources"); + + trackedResources.clear(); + sequenceFields = emptyMap(); + bodyByteBuf = null; + cachedFirstKey = null; } - private interface Finder { - @Nullable - T find(ByteBuf byteBuf, BsonBinaryReader bsonReader); - @Nullable - T notFound(); + /** + * Throws IllegalStateException if this document has been closed. + */ + private void ensureOpen() { + if (closed) { + throw new IllegalStateException("The BsonDocument resources have been released."); + } + } + + // ==================== Utility Methods ==================== + + /** + * Reads a null-terminated C-string from the buffer. + * Used for parsing OP_MSG sequence identifiers. + */ + private static String readCString(final ByteBuf byteBuf) { + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + byte b = byteBuf.get(); + while (b != 0) { + bytes.write(b); + b = byteBuf.get(); + } + try { + return bytes.toString(StandardCharsets.UTF_8.name()); + } catch (UnsupportedEncodingException e) { + throw new MongoInternalException("Unexpected exception", e); + } } - // see https://docs.oracle.com/javase/6/docs/platform/serialization/spec/output.html + /** + * Serialization support: converts to a regular BsonDocument before serialization. + */ private Object writeReplace() { - return toBaseBsonDocument(); + ensureOpen(); + return toBsonDocument(); } - // see https://docs.oracle.com/javase/6/docs/platform/serialization/spec/input.html private void readObject(final ObjectInputStream stream) throws InvalidObjectException { throw new InvalidObjectException("Proxy required"); } - private class ByteBufBsonDocumentEntrySet extends AbstractSet> { - @Override - public Iterator> iterator() { - return new Iterator>() { - private final ByteBuf duplicatedByteBuf = byteBuf.duplicate(); - private final BsonBinaryReader bsonReader; - - { - bsonReader = new BsonBinaryReader(new ByteBufferBsonInput(duplicatedByteBuf)); - bsonReader.readStartDocument(); - bsonReader.readBsonType(); - } - - @Override - public boolean hasNext() { - return bsonReader.getCurrentBsonType() != BsonType.END_OF_DOCUMENT; - } + // ==================== Inner Classes ==================== - @Override - public Entry next() { - if (!hasNext()) { - throw new NoSuchElementException(); - } - String key = bsonReader.readName(); - BsonValue value = readBsonValue(duplicatedByteBuf, bsonReader); - bsonReader.readBsonType(); - return new AbstractMap.SimpleEntry<>(key, value); - } + /** + * Represents an OP_MSG Type 1 document sequence section. + * + *

+     * <p>A sequence field contains a contiguous series of BSON documents in the buffer.
+     * When accessed via {@link #asArray()}, it returns a {@link BsonArray} containing
+     * {@link ByteBufBsonDocument} instances for each document.</p>
+     *
+     * <p>The documents are lazily parsed on first access and cached for subsequent calls.</p>

+ */ + private static final class SequenceField { + /** Buffer containing the sequence of BSON documents */ + private final ByteBuf sequenceByteBuf; - }; - } + /** Reference to parent's tracked resources for registering created documents */ + private final List trackedResources; - @Override - public boolean isEmpty() { - return !iterator().hasNext(); - } + /** Cached list of parsed documents, populated on first access */ + private List documents; - @Override - public int size() { - return ByteBufBsonDocument.this.size(); + SequenceField(final ByteBuf sequenceByteBuf, final List trackedResources) { + this.sequenceByteBuf = sequenceByteBuf; + this.trackedResources = trackedResources; + trackedResources.add(sequenceByteBuf::release); } - } - private class ByteBufBsonDocumentKeySet extends AbstractSet { - @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") - private final Set> entrySet = new ByteBufBsonDocumentEntrySet(); - - @Override - public Iterator iterator() { - final Iterator> entrySetIterator = entrySet.iterator(); - return new Iterator() { - @Override - public boolean hasNext() { - return entrySetIterator.hasNext(); - } + /** + * Returns this sequence as a BsonArray of ByteBufBsonDocument instances. + * + *

+     * <p>On first call, parses the buffer to create a ByteBufBsonDocument for each
+     * document and registers each one with the parent's tracked resources.</p>
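+     * <p>A sketch of the result shape (assuming a hypothetical sequence holding two documents):</p>
+     * <pre>{@code
+     * BsonValue value = sequenceField.asArray();
+     * value.asArray().size(); // 2, each element a lazily-parsed ByteBufBsonDocument
+     * }</pre>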
+ * + * @return A BsonArray containing the sequence documents + */ + BsonValue asArray() { + if (documents == null) { + documents = new ArrayList<>(); + ByteBuf dup = sequenceByteBuf.duplicate(); + try { + while (dup.hasRemaining()) { + // Read document size to determine bounds + int docStart = dup.position(); + int docSize = dup.getInt(); + int docEnd = docStart + docSize; - @Override - public String next() { - return entrySetIterator.next().getKey(); + // Create a view of just this document's bytes + ByteBuf docBuf = sequenceByteBuf.duplicate().position(docStart).limit(docEnd); + ByteBufBsonDocument doc = new ByteBufBsonDocument(docBuf); + // Track for cleanup when parent is closed + trackedResources.add(doc); + documents.add(doc); + dup.position(docEnd); + } + } finally { + dup.release(); } - }; + } + // Return a new array each time to prevent external modification of cached list + return new BsonArray(new ArrayList<>(documents)); } - @Override - public boolean isEmpty() { - return entrySet.isEmpty(); + /** + * Checks if this sequence contains the given value. + */ + boolean containsValue(final Object value) { + return value instanceof BsonValue && asArray().asArray().contains(value); } - @Override - public int size() { - return entrySet.size(); + /** + * Converts this sequence to a BsonArray of regular BsonDocument instances. + * + *

+     * <p>Used by {@link ByteBufBsonDocument#toBsonDocument()} to fully hydrate the document.
+     * Unlike {@link #asArray()}, this creates regular BsonDocument instances, not
+     * ByteBufBsonDocument wrappers.</p>

+ * + * @return A BsonArray containing fully deserialized BsonDocument instances + */ + BsonArray toHydratedArray() { + ByteBuf dup = sequenceByteBuf.duplicate(); + try { + List hydratedDocs = new ArrayList<>(); + while (dup.hasRemaining()) { + int docStart = dup.position(); + int docSize = dup.getInt(); + int docEnd = docStart + docSize; + ByteBuf docBuf = dup.duplicate().position(docStart).limit(docEnd); + try (BsonBinaryReader reader = new BsonBinaryReader(new ByteBufferBsonInput(docBuf))) { + hydratedDocs.add(new BsonDocumentCodec().decode(reader, DecoderContext.builder().build())); + } finally { + docBuf.release(); + } + dup.position(docEnd); + } + return new BsonArray(hydratedDocs); + } finally { + dup.release(); + } } } - private class ByteBufBsonDocumentValuesCollection extends AbstractCollection { - @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") - private final Set> entrySet = new ByteBufBsonDocumentEntrySet(); - - @Override - public Iterator iterator() { - final Iterator> entrySetIterator = entrySet.iterator(); - return new Iterator() { - @Override - public boolean hasNext() { - return entrySetIterator.hasNext(); - } + /** + * An iterator that combines two iterators sequentially. + * + *

+     * <p>Used to merge body field iteration with sequence field iteration,
+     * presenting a unified view of all document fields.</p>
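+     * <p>A minimal sketch ({@code bodyKeys} and {@code sequenceKeys} are hypothetical iterators):</p>
+     * <pre>{@code
+     * Iterator<String> it = new CombinedIterator<>(bodyKeys, sequenceKeys);
+     * // exhausts bodyKeys first, then continues with sequenceKeys
+     * }</pre>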
+ * + * @param The type of elements returned by the iterator + */ + private static final class CombinedIterator implements Iterator { + private final Iterator primary; + private final Iterator secondary; - @Override - public BsonValue next() { - return entrySetIterator.next().getValue(); - } - }; + CombinedIterator(final Iterator primary, final Iterator secondary) { + this.primary = primary; + this.secondary = secondary; } @Override - public boolean isEmpty() { - return entrySet.isEmpty(); + public boolean hasNext() { + return primary.hasNext() || secondary.hasNext(); } + @Override - public int size() { - return entrySet.size(); + public T next() { + if (primary.hasNext()) { + return primary.next(); + } + if (secondary.hasNext()) { + return secondary.next(); + } + throw new NoSuchElementException(); } } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonHelper.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonHelper.java index 55054112bf2..4d4d4846afa 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonHelper.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufBsonHelper.java @@ -38,18 +38,25 @@ import org.bson.codecs.BsonDocumentCodec; import org.bson.codecs.DecoderContext; +import java.io.Closeable; +import java.util.List; + final class ByteBufBsonHelper { - static BsonValue readBsonValue(final ByteBuf byteBuf, final BsonBinaryReader bsonReader) { + static BsonValue readBsonValue(final ByteBuf byteBuf, final BsonBinaryReader bsonReader, final List trackedResources) { BsonValue value; switch (bsonReader.getCurrentBsonType()) { case DOCUMENT: ByteBuf documentByteBuf = byteBuf.duplicate(); - value = new ByteBufBsonDocument(documentByteBuf); + ByteBufBsonDocument document = new ByteBufBsonDocument(documentByteBuf); + trackedResources.add(document); + value = document; bsonReader.skipValue(); break; case ARRAY: ByteBuf arrayByteBuf = byteBuf.duplicate(); - value = new ByteBufBsonArray(arrayByteBuf); + ByteBufBsonArray array = new ByteBufBsonArray(arrayByteBuf); + trackedResources.add(array); + value = array; bsonReader.skipValue(); break; case INT32: diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java index 1edbb0f4c2f..0988679f28d 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java @@ -16,6 +16,9 @@ package com.mongodb.internal.connection; +import com.mongodb.annotations.Sealed; +import com.mongodb.internal.ResourceUtil; +import com.mongodb.internal.VisibleForTesting; import org.bson.BsonSerializationException; import org.bson.ByteBuf; import org.bson.io.OutputBuffer; @@ -28,11 +31,28 @@ import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE; import static java.lang.String.format; /** + * A BSON output implementation that uses pooled {@link ByteBuf} instances for efficient memory management. + * + *

+ * <p><b>ByteBuf Ownership and Lifecycle</b></p>
+ * <p>This class manages the lifecycle of {@link ByteBuf} instances obtained from the {@link BufferProvider}.
+ * The ownership model is as follows:</p>
+ * <ul>
+ * <li>Internal buffers are owned by this output and released when {@link #close()} is called or
+ * when {@link #truncateToPosition(int)} removes them.</li>
+ * <li>Methods that return {@link ByteBuf} instances (e.g., {@link #getByteBuffers()}) return
+ * duplicates with their own reference counts. Callers are responsible for releasing
+ * these buffers to prevent memory leaks.</li>
+ * <li>The {@link Branch} subclass merges its buffers into the parent on close, transferring
+ * ownership by retaining buffers before the branch releases them.</li>
+ * </ul>
+ *
 * <p>This class is not part of the public API and may be removed or changed at any time</p>
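+ *
+ * <p>A minimal usage sketch (assuming a {@code bufferProvider} and an {@code outputStream} are in scope):</p>
+ * <pre>{@code
+ * try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(bufferProvider)) {
+ *     out.writeInt32(42);     // may span multiple pooled buffers as the output grows
+ *     out.pipe(outputStream); // duplicates are created and released internally
+ * } // close() releases all pooled buffers back to the provider
+ * }</pre>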
*/ +@Sealed public class ByteBufferBsonOutput extends OutputBuffer { private static final int MAX_SHIFT = 31; @@ -50,6 +70,9 @@ public class ByteBufferBsonOutput extends OutputBuffer { /** * Construct an instance that uses the given buffer provider to allocate byte buffers as needs as it grows. * + *

+     * <p>The buffer provider is used to allocate new {@link ByteBuf} instances as the output grows.
+     * All allocated buffers are owned by this output and will be released when {@link #close()} is called.</p>

+ * * @param bufferProvider the non-null buffer provider */ public ByteBufferBsonOutput(final BufferProvider bufferProvider) { @@ -63,6 +86,10 @@ public ByteBufferBsonOutput(final BufferProvider bufferProvider) { * If multiple branches are created, they are merged in the order they are {@linkplain ByteBufferBsonOutput.Branch#close() closed}. * {@linkplain #close() Closing} this {@link ByteBufferBsonOutput} does not {@linkplain ByteBufferBsonOutput.Branch#close() close} the branch. * + *

+     * <p><b>ByteBuf Ownership:</b> The branch allocates its own buffers. When the branch is closed,
+     * ownership of these buffers is transferred to the parent by retaining them before the branch releases
+     * its references. The parent then becomes responsible for releasing these buffers when it is closed.</p>
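+     * <p>An illustrative sketch of the branch lifecycle ({@code out} is a hypothetical parent output):</p>
+     * <pre>{@code
+     * try (ByteBufferBsonOutput.Branch branch = out.branch()) {
+     *     branch.writeInt32(42); // written to buffers owned by the branch
+     * } // closing retains the branch's buffers and merges them into 'out'
+     * }</pre>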
+ * * @return A new {@link ByteBufferBsonOutput.Branch}. */ public ByteBufferBsonOutput.Branch branch() { @@ -223,10 +250,28 @@ protected void write(final int absolutePosition, final int value) { byteBuffer.put(bufferPositionPair.position++, (byte) value); } + /** + * Returns a list of duplicated byte buffers containing the written data, flipped for reading. + * + *

+     * <p><b>ByteBuf Ownership:</b> The returned buffers are duplicates with their own
+     * reference counts (each starts with a reference count of 1). The caller is responsible
+     * for releasing each buffer when done to prevent memory leaks. Example usage:</p>
+     * <pre>{@code
+     * List<ByteBuf> buffers = output.getByteBuffers();
+     * try {
+     *     // use buffers
+     * } finally {
+     *     ResourceUtil.release(buffers);
+     * }
+     * }</pre>
+     * <p><b>Note:</b> These buffers must be released before this {@code ByteBufferBsonOutput} is closed.
+     * Otherwise there is a risk that the buffers are released back to the bufferProvider while still
+     * in use, leading to data corruption.</p>

+ * + * @return a list of duplicated buffers, flipped for reading + */ @Override public List getByteBuffers() { ensureOpen(); - List buffers = new ArrayList<>(bufferList.size()); for (final ByteBuf cur : bufferList) { buffers.add(cur.duplicate().order(ByteOrder.LITTLE_ENDIAN).flip()); @@ -234,6 +279,17 @@ public List getByteBuffers() { return buffers; } + /** + * Returns a list of duplicated byte buffers without flipping them. + * + *

ByteBuf Ownership: The returned buffers are duplicates with their own + * reference counts (each starts with a reference count of 1). The caller is responsible + * for releasing each buffer when done to prevent memory leaks.

+ * + * @return a list of duplicated buffers + * @see #getByteBuffers() + */ + @VisibleForTesting(otherwise = PRIVATE) public List getDuplicateByteBuffers() { ensureOpen(); @@ -245,6 +301,13 @@ public List getDuplicateByteBuffers() { } + /** + * {@inheritDoc} + * + *

ByteBuf Management: This method obtains duplicated buffers via + * {@link #getByteBuffers()} and releases them after writing to the output stream, + * ensuring no buffer leaks occur.

+ */ @Override public int pipe(final OutputStream out) throws IOException { ensureOpen(); @@ -263,11 +326,20 @@ public int pipe(final OutputStream out) throws IOException { total += cur.limit(); } } finally { - byteBuffers.forEach(ByteBuf::release); + ResourceUtil.release(byteBuffers); } return total; } + /** + * Truncates this output to the specified position, releasing any buffers that are no longer needed. + * + *

ByteBuf Management: Any buffers beyond the new position are removed from + * the internal buffer list and released. This ensures no memory leaks when truncating.

+ * + * @param newPosition the new position to truncate to + * @throws IllegalArgumentException if newPosition is negative or greater than the current position + */ @Override public void truncateToPosition(final int newPosition) { ensureOpen(); @@ -306,13 +378,15 @@ public final void flush() throws IOException { * {@inheritDoc} *

* Idempotent.

+ * + *

ByteBuf Management: Releases internal buffers and clears the buffer list. + * After this method returns, all buffers that were allocated by this output will have been fully released + * back to the buffer provider.

*/ @Override public void close() { if (isOpen()) { - for (final ByteBuf cur : bufferList) { - cur.release(); - } + ResourceUtil.release(bufferList); currentByteBuffer = null; bufferList.clear(); closed = true; @@ -345,7 +419,14 @@ boolean isOpen() { } /** - * @see #branch() + * Merges a branch's buffers into this output. + * + *

ByteBuf Ownership: This method retains each buffer from the branch before + * adding it to this output's buffer list. This is necessary because the branch will release its + * references when it closes. The retain ensures the buffers remain valid and are now owned by + * this output.

+ * + * @param branch the branch to merge */ private void merge(final ByteBufferBsonOutput branch) { assertTrue(branch instanceof ByteBufferBsonOutput.Branch); @@ -356,6 +437,20 @@ private void merge(final ByteBufferBsonOutput branch) { currentByteBuffer = null; } + /** + * A branch of a {@link ByteBufferBsonOutput} that can be merged back into its parent. + * + *

ByteBuf Ownership: A branch allocates its own buffers independently. + * When {@link #close()} is called:

+ *
    + *
  1. The parent's {@link ByteBufferBsonOutput#merge(ByteBufferBsonOutput)} method is called, + * which retains all buffers in this branch.
  2. + *
  3. Then {@code super.close()} is called, which releases the branch's references to the buffers.
  4. + *
+ *

The retain/release sequence ensures buffers are safely transferred to the parent without leaks.

+ * + * @see #branch() + */ public static final class Branch extends ByteBufferBsonOutput { private final ByteBufferBsonOutput parent; @@ -365,6 +460,16 @@ private Branch(final ByteBufferBsonOutput parent) { } /** + * Closes this branch and merges its data into the parent output. + * + *

ByteBuf Ownership: On close, this branch's buffers are transferred + * to the parent. The parent retains the buffers (incrementing reference counts), and then + * this branch releases only its own single reference. The parent + * becomes the sole owner of the buffers and is responsible for releasing them.

+ * + *

Idempotent. If already closed, this method does nothing.

+ * + * @throws AssertionError if the parent has been closed before this branch * @see #branch() */ @Override diff --git a/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java b/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java index 348349fd18c..f3c12802173 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java +++ b/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java @@ -23,11 +23,11 @@ import com.mongodb.ServerApi; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.internal.MongoNamespaceHelper; +import com.mongodb.internal.ResourceUtil; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.connection.MessageSequences.EmptyMessageSequences; import com.mongodb.internal.session.SessionContext; import com.mongodb.lang.Nullable; -import org.bson.BsonArray; import org.bson.BsonBinaryWriter; import org.bson.BsonBoolean; import org.bson.BsonDocument; @@ -61,8 +61,6 @@ import static com.mongodb.internal.connection.BsonWriterHelper.encodeUsingRegistry; import static com.mongodb.internal.connection.BsonWriterHelper.writeDocumentsOfDualMessageSequences; import static com.mongodb.internal.connection.BsonWriterHelper.writePayload; -import static com.mongodb.internal.connection.ByteBufBsonDocument.createList; -import static com.mongodb.internal.connection.ByteBufBsonDocument.createOne; import static com.mongodb.internal.connection.ReadConcernHelper.getReadConcernDocument; import static com.mongodb.internal.operation.ServerVersionHelper.UNKNOWN_WIRE_VERSION; @@ -143,54 +141,26 @@ public final class CommandMessage extends RequestMessage { } /** - * Create a BsonDocument representing the logical document encoded by an OP_MSG. + * Create a ByteBufBsonDocument representing the logical document encoded by an OP_MSG. *
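[Editor's note — illustrative sketch, not part of the patch] The branch contract documented above implies the following usage shape; the variable names and the writeInt32 call are assumptions for illustration only:

    try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(bufferProvider)) {
        try (ByteBufferBsonOutput.Branch branch = out.branch()) {
            branch.writeInt32(42); // written to buffers owned by the branch
        } // on close: the parent retains the branch's buffers, the branch releases its references
        // `out` is now the sole owner of the merged buffers and releases them on close
    }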

diff --git a/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java b/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java
index 348349fd18c..f3c12802173 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/CommandMessage.java
@@ -23,11 +23,11 @@
 import com.mongodb.ServerApi;
 import com.mongodb.connection.ClusterConnectionMode;
 import com.mongodb.internal.MongoNamespaceHelper;
+import com.mongodb.internal.ResourceUtil;
 import com.mongodb.internal.TimeoutContext;
 import com.mongodb.internal.connection.MessageSequences.EmptyMessageSequences;
 import com.mongodb.internal.session.SessionContext;
 import com.mongodb.lang.Nullable;
-import org.bson.BsonArray;
 import org.bson.BsonBinaryWriter;
 import org.bson.BsonBoolean;
 import org.bson.BsonDocument;
@@ -61,8 +61,6 @@
 import static com.mongodb.internal.connection.BsonWriterHelper.encodeUsingRegistry;
 import static com.mongodb.internal.connection.BsonWriterHelper.writeDocumentsOfDualMessageSequences;
 import static com.mongodb.internal.connection.BsonWriterHelper.writePayload;
-import static com.mongodb.internal.connection.ByteBufBsonDocument.createList;
-import static com.mongodb.internal.connection.ByteBufBsonDocument.createOne;
 import static com.mongodb.internal.connection.ReadConcernHelper.getReadConcernDocument;
 import static com.mongodb.internal.operation.ServerVersionHelper.UNKNOWN_WIRE_VERSION;
@@ -143,54 +141,26 @@ public final class CommandMessage extends RequestMessage {
     }

     /**
-     * Create a BsonDocument representing the logical document encoded by an OP_MSG.
+     * Create a ByteBufBsonDocument representing the logical document encoded by an OP_MSG.
      * <p>
      * The returned document will contain all the fields from the `PAYLOAD_TYPE_0_DOCUMENT` section, as well as all fields represented by
      * `PAYLOAD_TYPE_1_DOCUMENT_SEQUENCE` sections.
+     *
+     * <p><b>Note:</b> This document MUST be closed after use, otherwise when using Netty it could report the leaking of resources
+     * when the underlying {@code ByteBuf}s are garbage collected.</p>
      */
-    BsonDocument getCommandDocument(final ByteBufferBsonOutput bsonOutput) {
+    ByteBufBsonDocument getCommandDocument(final ByteBufferBsonOutput bsonOutput) {
         List<ByteBuf> byteBuffers = bsonOutput.getByteBuffers();
         try {
-            CompositeByteBuf byteBuf = new CompositeByteBuf(byteBuffers);
+            CompositeByteBuf compositeByteBuf = new CompositeByteBuf(byteBuffers);
             try {
-                byteBuf.position(firstDocumentPosition);
-                ByteBufBsonDocument byteBufBsonDocument = createOne(byteBuf);
-
-                // If true, it means there is at least one `PAYLOAD_TYPE_1_DOCUMENT_SEQUENCE` section in the OP_MSG
-                if (byteBuf.hasRemaining()) {
-                    BsonDocument commandBsonDocument = byteBufBsonDocument.toBaseBsonDocument();
-
-                    // Each loop iteration processes one Document Sequence
-                    // When there are no more bytes remaining, there are no more Document Sequences
-                    while (byteBuf.hasRemaining()) {
-                        // skip reading the payload type, we know it is `PAYLOAD_TYPE_1`
-                        byteBuf.position(byteBuf.position() + 1);
-                        int sequenceStart = byteBuf.position();
-                        int sequenceSizeInBytes = byteBuf.getInt();
-                        int sectionEnd = sequenceStart + sequenceSizeInBytes;
-
-                        String fieldName = getSequenceIdentifier(byteBuf);
-                        // If this assertion fires, it means that the driver has started using document sequences for nested fields. If
-                        // so, this method will need to change in order to append the value to the correct nested document.
-                        assertFalse(fieldName.contains("."));
-
-                        ByteBuf documentsByteBufSlice = byteBuf.duplicate().limit(sectionEnd);
-                        try {
-                            commandBsonDocument.append(fieldName, new BsonArray(createList(documentsByteBufSlice)));
-                        } finally {
-                            documentsByteBufSlice.release();
-                        }
-                        byteBuf.position(sectionEnd);
-                    }
-                    return commandBsonDocument;
-                } else {
-                    return byteBufBsonDocument;
-                }
+                compositeByteBuf.position(firstDocumentPosition);
+                return ByteBufBsonDocument.createCommandMessage(compositeByteBuf);
             } finally {
-                byteBuf.release();
+                compositeByteBuf.release();
             }
         } finally {
-            byteBuffers.forEach(ByteBuf::release);
+            ResourceUtil.release(byteBuffers);
         }
     }
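[Editor's note — illustrative sketch, not part of the patch] Given the MUST-close contract above, callers are expected to treat the returned document as a resource; this is exactly how InternalStreamConnection consumes it later in this patch:

    try (ByteBufBsonDocument commandDocument = message.getCommandDocument(bsonOutput)) {
        String commandName = commandDocument.getFirstKey();
        // log / trace using commandDocument while it is still open
    } // close() releases the duplicated buffers so Netty's leak detector stays quiet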

diff --git a/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java b/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java
index 67121ecbe3c..e3eb299a40e 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/CompositeByteBuf.java
@@ -24,17 +24,55 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;

 import static java.lang.String.format;
 import static org.bson.assertions.Assertions.isTrueArgument;
 import static org.bson.assertions.Assertions.notNull;

-class CompositeByteBuf implements ByteBuf {
+/**
+ * A composite {@link ByteBuf} that provides a unified view over a list of component buffers.
+ *
+ * <p><b>ByteBuf Ownership and Reference Counting</b></p>
+ *
+ * <p>This class manages the lifecycle of its component buffers with the following rules:</p>
+ * <ul>
+ * <li><b>Constructor:</b> Takes buffers as input but does NOT take ownership. The buffers are made
+ * read-only via {@link ByteBuf#asReadOnly()}, which creates read-only duplicates.
+ * The original buffers remain owned by the caller.</li>
+ * <li><b>{@link #duplicate()}:</b> Creates a new composite with an independent position/limit but calls
+ * {@link ByteBuf#retain()} on each component, incrementing their reference counts. The duplicate
+ * owns these retained references and releases them when it is released.</li>
+ * <li><b>{@link #retain()}:</b> Increments the composite's reference count AND retains all component
+ * buffers. Each retain() call must be paired with a {@link #release()}.</li>
+ * <li><b>{@link #release()}:</b> Decrements the composite's reference count AND releases all component
+ * buffers. When the count reaches 0, subsequent access will throw {@link IllegalStateException}.</li>
+ * </ul>
+ *
+ * <p><b>Important:</b> The composite's reference count is independent from its components'
+ * reference counts, but they are kept in sync via the {@link #retain()} and {@link #release()} methods.</p>
+ *
+ * <p>This class is not part of the public API and may be removed or changed at any time</p>
+ */
+final class CompositeByteBuf implements ByteBuf {
     private final List<Component> components;
     private final AtomicInteger referenceCount = new AtomicInteger(1);
     private int position;
     private int limit;

+    /**
+     * Creates a composite buffer from the given list of buffers.
+     *
+     * <p><b>ByteBuf Ownership:</b> This constructor does NOT take ownership of the input buffers.
+     * It calls {@link ByteBuf#asReadOnly()} on each buffer, which creates a shallow read-only view that
+     * shares the same underlying data and reference count as the original. The caller retains ownership
+     * of the original buffers and is responsible for their lifecycle.</p>
+     *
+     * <p>The composite starts with a reference count of 1. When {@link #release()} is called and the
+     * reference count reaches 0, it does NOT automatically release the original buffers - the caller
+     * must handle that separately.</p>
+     *
+     * @param buffers the list of buffers to compose, must not be null or empty
+     */
     CompositeByteBuf(final List<ByteBuf> buffers) {
         notNull("buffers", buffers);
         isTrueArgument("buffer list not empty", !buffers.isEmpty());
@@ -50,7 +88,9 @@ class CompositeByteBuf implements ByteBuf {
     }

     private CompositeByteBuf(final CompositeByteBuf from) {
-        components = from.components;
+        components = from.components.stream().map(c ->
+                        new Component(c.buffer.retain(), c.offset))
+                .collect(Collectors.toList());
         position = from.position();
         limit = from.limit();
     }
@@ -277,6 +317,19 @@ public ByteBuf asReadOnly() {
         throw new UnsupportedOperationException();
     }

+    /**
+     * Creates a duplicate of this composite buffer with independent position and limit.
+     *
+     * <p><b>ByteBuf Ownership:</b> The duplicate calls {@link ByteBuf#retain()} on each
+     * component buffer, incrementing their reference counts. The duplicate owns these retained
+     * references and starts with its own reference count of 1. The caller is responsible
+     * for releasing the duplicate when done, which will release the component buffers.</p>
+     *
+     * <p>The duplicate shares the underlying buffer data with this composite but has independent
+     * reference counting and position/limit state.</p>
+     *
+     * @return a new composite buffer that shares data with this one but has independent state
+     */
     @Override
     public ByteBuf duplicate() {
         return new CompositeByteBuf(this);
@@ -300,21 +353,42 @@ public int getReferenceCount() {
         return referenceCount.get();
     }

+    /**
+     * Increments the reference count of this composite buffer and all component buffers.
+     *
+     * <p><b>Important:</b> This method retains both the composite's reference count and all
+     * component buffers. If the reference count is already 0, an {@link IllegalStateException} is thrown;
+     * in that case the component buffers will not have been retained, since the check happens before
+     * the components are touched.</p>
+     *
+     * @return this buffer
+     * @throws IllegalStateException if the reference count is already 0
+     */
     @Override
     public ByteBuf retain() {
         if (referenceCount.incrementAndGet() == 1) {
             referenceCount.decrementAndGet();
             throw new IllegalStateException("Attempted to increment the reference count when it is already 0");
         }
+        components.forEach(c -> c.buffer.retain());
         return this;
     }

+    /**
+     * Decrements the reference count of this composite buffer and all component buffers.
+     *
+     * <p><b>Important:</b> This method releases both the composite's reference count and all
+     * component buffers.</p>
+     *
+     * @throws IllegalStateException if the reference count is already 0
+     */
     @Override
     public void release() {
         if (referenceCount.decrementAndGet() < 0) {
             referenceCount.incrementAndGet();
             throw new IllegalStateException("Attempted to decrement the reference count below 0");
         }
+        components.forEach(c -> c.buffer.release());
     }

     private Component findComponent(final int index) {
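[Editor's note — illustrative sketch, not part of the patch] With retain/release now cascading to the components, every duplicate() or retain() must be balanced exactly once; the variable names below are assumptions for illustration:

    CompositeByteBuf composite = new CompositeByteBuf(buffers); // count 1, components not owned
    ByteBuf view = composite.duplicate();                       // retains every component
    try {
        // read from `view` with its own independent position/limit
    } finally {
        view.release();     // drops the retained component references
    }
    composite.release();    // pairs with the count taken in the constructor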
diff --git a/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java b/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java
index bb97517d315..6eeea6b030b 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java
@@ -187,11 +187,17 @@ class ServerMonitor extends Thread implements AutoCloseable {
         @Override
         public void close() {
-            interrupt();
-            InternalConnection connection = this.connection;
-            if (connection != null) {
-                connection.close();
+            InternalConnection localConnection = withLock(lock, () -> {
+                InternalConnection result = connection;
+                connection = null;
+                return result;
+            });
+
+            if (localConnection != null) {
+                localConnection.close();
             }
+
+            interrupt();
         }

         @Override
@@ -293,24 +299,36 @@ private ServerDescription setupNewConnectionAndGetInitialDescription(final boole
          */
         private ServerDescription doHeartbeat(final ServerDescription currentServerDescription,
                                               final boolean shouldStreamResponses) {
+            // Check if monitor was closed or connection is unusable
+            InternalConnection localConnection = withLock(lock, () -> {
+                if (isClosed || connection == null || connection.isClosed()) {
+                    return null;
+                }
+                return connection;
+            });
+
+            if (localConnection == null) {
+                throw new MongoSocketException("Monitor closed", serverId.getAddress());
+            }
+
             try {
                 OperationContext operationContext = operationContextFactory.create();
-                if (!connection.hasMoreToCome()) {
+                if (!localConnection.hasMoreToCome()) {
                     BsonDocument helloDocument = new BsonDocument(getHandshakeCommandName(currentServerDescription), new BsonInt32(1))
                             .append("helloOk", BsonBoolean.TRUE);
                     if (shouldStreamResponses) {
                         helloDocument.append("topologyVersion", assertNotNull(currentServerDescription.getTopologyVersion()).asDocument());
                         helloDocument.append("maxAwaitTimeMS", new BsonInt64(serverSettings.getHeartbeatFrequency(MILLISECONDS)));
                     }
-                    connection.send(createCommandMessage(helloDocument, connection, currentServerDescription), new BsonDocumentCodec(),
+                    localConnection.send(createCommandMessage(helloDocument, localConnection, currentServerDescription), new BsonDocumentCodec(),
                             operationContext);
                 }
                 BsonDocument helloResult;
                 if (shouldStreamResponses) {
-                    helloResult = connection.receive(new BsonDocumentCodec(), operationContextWithAdditionalTimeout(operationContext));
+                    helloResult = localConnection.receive(new BsonDocumentCodec(), operationContextWithAdditionalTimeout(operationContext));
                 } else {
-                    helloResult = connection.receive(new BsonDocumentCodec(), operationContext);
+                    helloResult = localConnection.receive(new BsonDocumentCodec(), operationContext);
                 }
                 logAndNotifyHeartbeatSucceeded(shouldStreamResponses, helloResult);
                 return createServerDescription(serverId.getAddress(), helloResult, roundTripTimeSampler.getAverage(),
@@ -322,10 +340,18 @@ private ServerDescription doHeartbeat(final ServerDescription currentServerDescr
         }

         private void logAndNotifyHeartbeatStarted(final boolean shouldStreamResponses) {
-            alreadyLoggedHeartBeatStarted = true;
-            logHeartbeatStarted(serverId, connection.getDescription(), shouldStreamResponses);
-            serverMonitorListener.serverHearbeatStarted(new ServerHeartbeatStartedEvent(
-                    connection.getDescription().getConnectionId(), shouldStreamResponses));
+            ConnectionDescription description = connection.getDescription();
+            if (description != null) {
+                alreadyLoggedHeartBeatStarted = true;
+                logHeartbeatStarted(serverId, description, shouldStreamResponses);
+                serverMonitorListener.serverHearbeatStarted(new ServerHeartbeatStartedEvent(
+                        description.getConnectionId(), shouldStreamResponses));
+            } else {
+                // Connection not fully established yet - skip logging for this heartbeat
+                if (LOGGER.isDebugEnabled()) {
+                    LOGGER.debug(format("Skipping heartbeat started event for %s - connection description not available", serverId));
+                }
+            }
         }

         private void logAndNotifyHeartbeatSucceeded(final boolean shouldStreamResponses, final BsonDocument helloResult) {
@@ -343,12 +369,20 @@ private void logAndNotifyHeartbeatSucceeded(final boolean shouldStreamResponses,
         private void logAndNotifyHeartbeatFailed(final boolean shouldStreamResponses, final Exception e) {
             alreadyLoggedHeartBeatStarted = false;
             long elapsedTimeNanos = getElapsedTimeNanos();
-            logHeartbeatFailed(serverId, connection.getDescription(), shouldStreamResponses, elapsedTimeNanos, e);
-            serverMonitorListener.serverHeartbeatFailed(
-                    new ServerHeartbeatFailedEvent(connection.getDescription().getConnectionId(), elapsedTimeNanos,
-                            shouldStreamResponses, e));
-        }
+            ConnectionDescription description = connection != null ? connection.getDescription() : null;
+            if (description != null) {
+                logHeartbeatFailed(serverId, description, shouldStreamResponses, elapsedTimeNanos, e);
+                serverMonitorListener.serverHeartbeatFailed(
+                        new ServerHeartbeatFailedEvent(description.getConnectionId(), elapsedTimeNanos,
+                                shouldStreamResponses, e));
+            } else {
+                // Log failure without connection details
+                if (LOGGER.isDebugEnabled()) {
+                    LOGGER.debug(format("Heartbeat failed for %s but connection description not available", serverId), e);
+                }
+            }
+        }

         private long getElapsedTimeNanos() {
             return System.nanoTime() - lookupStartTimeNanos;
         }
@@ -514,11 +548,17 @@ private class RoundTripTimeMonitor extends Thread implements AutoCloseable {
         @Override
         public void close() {
-            interrupt();
-            InternalConnection connection = this.connection;
-            if (connection != null) {
-                connection.close();
+            InternalConnection localConnection = withLock(lock, () -> {
+                InternalConnection result = connection;
+                connection = null;
+                return result;
+            });
+
+            if (localConnection != null) {
+                localConnection.close();
             }
+
+            interrupt();
         }

         @Override
@@ -552,13 +592,45 @@ public void run() {
         }

         private void initialize() {
-            connection = null;
-            connection = internalConnectionFactory.create(serverId);
-            connection.open(operationContextFactory.create());
-            roundTripTimeSampler.addSample(connection.getInitialServerDescription().getRoundTripTimeNanos());
+            boolean shouldProceed = withLock(lock, () -> !isClosed);
+
+            if (!shouldProceed) {
+                return;
+            }
+
+            InternalConnection newConnection = internalConnectionFactory.create(serverId);
+            newConnection.open(operationContextFactory.create());
+
+            // Check again after the potentially long open() operation
+            boolean stillValid = withLock(lock, () -> {
+                if (!isClosed) {
+                    connection = newConnection;
+                    return true;
+                }
+                return false;
+            });
+
+            if (stillValid) {
+                roundTripTimeSampler.addSample(newConnection.getInitialServerDescription().getRoundTripTimeNanos());
+            } else {
+                // Monitor was closed during open(), clean up the connection
+                newConnection.close();
+            }
         }

         private void pingServer(final InternalConnection connection) {
+            // Atomically check if monitor was closed and connection is still valid
+            boolean shouldProceed = withLock(lock, () ->
+                    !isClosed && this.connection == connection
+            );
+
+            if (!shouldProceed) {
+                if (LOGGER.isDebugEnabled()) {
+                    LOGGER.debug(format("Skipping ping for %s - monitor closed or connection changed", serverId));
+                }
+                return; // Monitor closed or connection changed, skip ping
+            }
+
             long start = System.nanoTime();
             OperationContext operationContext = operationContextFactory.create();
             executeCommand("admin",
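[Editor's note — illustrative sketch, not part of the patch] The close() rewrites above all follow the same idiom: detach the shared field under the lock, then perform the slow close outside it. Reduced to its essentials:

    InternalConnection localConnection = withLock(lock, () -> {
        InternalConnection result = connection;
        connection = null;   // no other thread can observe a half-closed connection
        return result;
    });
    if (localConnection != null) {
        localConnection.close(); // potentially slow I/O, deliberately outside the lock
    }
    interrupt();                 // wake the monitor thread last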
diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java
index 6b20c467191..7d10e87b00e 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java
@@ -53,12 +53,10 @@
 import com.mongodb.internal.session.SessionContext;
 import com.mongodb.internal.time.Timeout;
 import com.mongodb.lang.Nullable;
-import org.bson.BsonBinaryReader;
 import org.bson.BsonDocument;
 import org.bson.ByteBuf;
 import org.bson.codecs.BsonDocumentCodec;
 import org.bson.codecs.Decoder;
-import org.bson.io.ByteBufferBsonInput;

 import java.io.IOException;
 import java.net.SocketTimeoutException;
@@ -397,7 +395,6 @@ public <T> T sendAndReceive(final CommandMessage message, final Decoder<T> decod
     public <T> void sendAndReceiveAsync(final CommandMessage message, final Decoder<T> decoder,
             final OperationContext operationContext, final SingleResultCallback<T> callback) {
-
         AsyncSupplier<T> sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal(
                 message, decoder, operationContext, c);
         beginAsync().thenSupply(c -> {
@@ -444,46 +441,44 @@ private <T> T sendAndReceiveInternal(final CommandMessage message, final Decoder<T>
         Span tracingSpan;
         try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this)) {
             message.encode(bsonOutput, operationContext);
-            tracingSpan = operationContext
-                    .getTracingManager()
-                    .createTracingSpan(message,
-                            operationContext,
-                            () -> message.getCommandDocument(bsonOutput),
-                            cmdName -> SECURITY_SENSITIVE_COMMANDS.contains(cmdName)
-                                    || SECURITY_SENSITIVE_HELLO_COMMANDS.contains(cmdName),
-                            () -> getDescription().getServerAddress(),
-                            () -> getDescription().getConnectionId()
-                    );
-
-            boolean isLoggingCommandNeeded = isLoggingCommandNeeded();
-            boolean isTracingCommandPayloadNeeded = tracingSpan != null && operationContext.getTracingManager().isCommandPayloadEnabled();
-
-            // Only hydrate the command document if necessary
-            BsonDocument commandDocument = null;
-            if (isLoggingCommandNeeded || isTracingCommandPayloadNeeded) {
-                commandDocument = message.getCommandDocument(bsonOutput);
-            }
-            if (isLoggingCommandNeeded) {
-                commandEventSender = new LoggingCommandEventSender(
-                        SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener,
-                        operationContext, message, commandDocument,
-                        COMMAND_PROTOCOL_LOGGER, loggerSettings);
-                commandEventSender.sendStartedEvent();
-            } else {
-                commandEventSender = new NoOpCommandEventSender();
-            }
-            if (isTracingCommandPayloadNeeded) {
-                tracingSpan.tagHighCardinality(QUERY_TEXT.asString(), commandDocument);
-            }
+            try (ByteBufBsonDocument commandDocument = message.getCommandDocument(bsonOutput)) {
+                tracingSpan = operationContext
+                        .getTracingManager()
+                        .createTracingSpan(message,
+                                operationContext,
+                                commandDocument,
+                                cmdName -> SECURITY_SENSITIVE_COMMANDS.contains(cmdName)
+                                        || SECURITY_SENSITIVE_HELLO_COMMANDS.contains(cmdName),
+                                () -> getDescription().getServerAddress(),
+                                () -> getDescription().getConnectionId()
+                        );
+
+                boolean isLoggingCommandNeeded = isLoggingCommandNeeded();
+
+                if (isLoggingCommandNeeded) {
+                    commandEventSender = new LoggingCommandEventSender(
+                            SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener,
+                            operationContext, message, commandDocument,
+                            COMMAND_PROTOCOL_LOGGER, loggerSettings);
+                    commandEventSender.sendStartedEvent();
+                } else {
+                    commandEventSender = new NoOpCommandEventSender();
+                }

-            try {
-                sendCommandMessage(message, bsonOutput, operationContext);
-            } catch (Exception e) {
-                if (tracingSpan != null) {
-                    tracingSpan.error(e);
+                boolean isTracingCommandPayloadNeeded = tracingSpan != null && operationContext.getTracingManager().isCommandPayloadEnabled();
+                if (isTracingCommandPayloadNeeded) {
+                    tracingSpan.tagHighCardinality(QUERY_TEXT.asString(), commandDocument);
+                }
+
+                try {
+                    sendCommandMessage(commandDocument.getFirstKey(), message, bsonOutput, operationContext);
+                } catch (Exception e) {
+                    if (tracingSpan != null) {
+                        tracingSpan.error(e);
+                    }
+                    commandEventSender.sendFailedEvent(e);
+                    throw e;
                 }
-                commandEventSender.sendFailedEvent(e);
-                throw e;
             }
         }
@@ -502,7 +497,9 @@ private <T> T sendAndReceiveInternal(final CommandMessage message, final Decoder<T>
     public <T> void send(final CommandMessage message, final Decoder<T> decoder, final OperationContext operationContext) {
         try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this)) {
             message.encode(bsonOutput, operationContext);
-            sendCommandMessage(message, bsonOutput, operationContext);
+            try (ByteBufBsonDocument commandDocument = message.getCommandDocument(bsonOutput)) {
+                sendCommandMessage(commandDocument.getFirstKey(), message, bsonOutput, operationContext);
+            }
             if (message.isResponseExpected()) {
                 hasMoreToCome = true;
             }
@@ -520,47 +517,26 @@ public boolean hasMoreToCome() {
         return hasMoreToCome;
     }

-    private void sendCommandMessage(final CommandMessage message, final ByteBufferBsonOutput bsonOutput,
-            final OperationContext operationContext) {
-
-        Compressor localSendCompressor = sendCompressor;
-        if (localSendCompressor == null || SECURITY_SENSITIVE_COMMANDS.contains(message.getCommandDocument(bsonOutput).getFirstKey())) {
-            trySendMessage(message, bsonOutput, operationContext);
-        } else {
-            ByteBufferBsonOutput compressedBsonOutput;
-            List<ByteBuf> byteBuffers = bsonOutput.getByteBuffers();
-            try {
-                CompressedMessage compressedMessage = new CompressedMessage(message.getOpCode(), byteBuffers, localSendCompressor,
-                        getMessageSettings(description, initialServerDescription));
-                compressedBsonOutput = new ByteBufferBsonOutput(this);
-                compressedMessage.encode(compressedBsonOutput, operationContext);
-            } finally {
-                ResourceUtil.release(byteBuffers);
-                bsonOutput.close();
-            }
-            trySendMessage(message, compressedBsonOutput, operationContext);
-        }
-        responseTo = message.getId();
-    }
-
-    private void trySendMessage(final CommandMessage message, final ByteBufferBsonOutput bsonOutput,
-            final OperationContext operationContext) {
-        Timeout.onExistsAndExpired(operationContext.getTimeoutContext().timeoutIncludingRoundTrip(), () -> {
-            throw TimeoutContext.createMongoRoundTripTimeoutException();
-        });
-        List<ByteBuf> byteBuffers = bsonOutput.getByteBuffers();
+    private void sendCommandMessage(final String commandName, final CommandMessage message,
+            final ByteBufferBsonOutput bsonOutput, final OperationContext operationContext) {
+        List<ByteBuf> messageByteBuffers = getMessageByteBuffers(commandName, message, bsonOutput, operationContext);
         try {
-            sendMessage(byteBuffers, message.getId(), operationContext);
+            Timeout.onExistsAndExpired(operationContext.getTimeoutContext().timeoutIncludingRoundTrip(), () -> {
+                throw TimeoutContext.createMongoRoundTripTimeoutException();
+            });
+            sendMessage(messageByteBuffers, message.getId(), operationContext);
         } finally {
-            ResourceUtil.release(byteBuffers);
-            bsonOutput.close();
+            ResourceUtil.release(messageByteBuffers);
        }
+        responseTo = message.getId();
     }

     private <T> T receiveCommandMessageResponse(final Decoder<T> decoder, final CommandEventSender commandEventSender,
             final OperationContext operationContext, @Nullable final Span tracingSpan) {
         boolean commandSuccessful = false;
-        try (ResponseBuffers responseBuffers = receiveResponseBuffers(operationContext)) {
+        ResponseBuffers responseBuffers = null;
+        try {
+            responseBuffers = receiveResponseBuffers(operationContext);
             updateSessionContext(operationContext.getSessionContext(), responseBuffers);
             if (!isCommandOk(responseBuffers)) {
                 throw getCommandFailureException(responseBuffers.getResponseDocument(responseTo,
@@ -591,67 +567,87 @@ private <T> T receiveCommandMessageResponse(final Decoder<T> decoder, final Comm
             }
             throw e;
         } finally {
+            if (responseBuffers != null) {
+                responseBuffers.close();
+            }
             if (tracingSpan != null) {
                 tracingSpan.end();
             }
         }
     }

+    private List<ByteBuf> getMessageByteBuffers(final String commandName, final CommandMessage message,
+            final ByteBufferBsonOutput bsonOutput, final OperationContext operationContext) {
+        Compressor localSendCompressor = sendCompressor;
+        List<ByteBuf> messageByteBuffers;
+        // Check if compressed
+        if (localSendCompressor == null || SECURITY_SENSITIVE_COMMANDS.contains(commandName)) {
+            messageByteBuffers = bsonOutput.getByteBuffers();
+        } else {
+            List<ByteBuf> byteBuffers = bsonOutput.getByteBuffers();
+            try {
+                CompressedMessage compressedMessage = new CompressedMessage(message.getOpCode(), byteBuffers, localSendCompressor,
+                        getMessageSettings(description, initialServerDescription));
+                try (ByteBufferBsonOutput compressedBsonOutput = new ByteBufferBsonOutput(this)) {
+                    compressedMessage.encode(compressedBsonOutput, operationContext);
+                    messageByteBuffers = compressedBsonOutput.getByteBuffers();
+                }
+            } finally {
+                ResourceUtil.release(byteBuffers);
+            }
+        }
+        return messageByteBuffers;
+    }
+
     private <T> void sendAndReceiveAsyncInternal(final CommandMessage message, final Decoder<T> decoder,
-            final OperationContext operationContext, final SingleResultCallback<T> callback) {
+            final OperationContext operationContext, final SingleResultCallback<T> callback) {
         if (isClosed()) {
             callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress()));
             return;
         }

+        // Async try with resources - release after the write
         ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this);
-        ByteBufferBsonOutput compressedBsonOutput = new ByteBufferBsonOutput(this);
-
         try {
             message.encode(bsonOutput, operationContext);
+            String commandName;
             CommandEventSender commandEventSender;
-            if (isLoggingCommandNeeded()) {
-                BsonDocument commandDocument = message.getCommandDocument(bsonOutput);
-                commandEventSender = new LoggingCommandEventSender(
-                        SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener,
-                        operationContext, message, commandDocument,
-                        COMMAND_PROTOCOL_LOGGER, loggerSettings);
-            } else {
-                commandEventSender = new NoOpCommandEventSender();
-            }
-
-            commandEventSender.sendStartedEvent();
-            Compressor localSendCompressor = sendCompressor;
-            if (localSendCompressor == null || SECURITY_SENSITIVE_COMMANDS.contains(message.getCommandDocument(bsonOutput).getFirstKey())) {
-                sendCommandMessageAsync(message.getId(), decoder, operationContext, callback, bsonOutput, commandEventSender,
-                        message.isResponseExpected());
-            } else {
-                List<ByteBuf> byteBuffers = bsonOutput.getByteBuffers();
-                try {
-                    CompressedMessage compressedMessage = new CompressedMessage(message.getOpCode(), byteBuffers, localSendCompressor,
-                            getMessageSettings(description, initialServerDescription));
-                    compressedMessage.encode(compressedBsonOutput, operationContext);
-                } finally {
-                    ResourceUtil.release(byteBuffers);
-                    bsonOutput.close();
+            try (ByteBufBsonDocument commandDocument = message.getCommandDocument(bsonOutput)) {
+                commandName = commandDocument.getFirstKey();
+                if (isLoggingCommandNeeded()) {
+                    commandEventSender = new LoggingCommandEventSender(
+                            SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener,
+                            operationContext, message, commandDocument,
+                            COMMAND_PROTOCOL_LOGGER, loggerSettings);
+                } else {
+                    commandEventSender = new NoOpCommandEventSender();
                 }
-                sendCommandMessageAsync(message.getId(), decoder, operationContext, callback, compressedBsonOutput, commandEventSender,
-                        message.isResponseExpected());
+                commandEventSender.sendStartedEvent();
             }
+
+            List<ByteBuf> messageByteBuffers = getMessageByteBuffers(commandName, message, bsonOutput, operationContext);
+            sendCommandMessageAsync(messageByteBuffers, message.getId(), decoder, operationContext,
+                    commandEventSender, message.isResponseExpected(), (r, t) -> {
+                        ResourceUtil.release(messageByteBuffers);
+                        bsonOutput.close(); // Close AFTER async write completes
+                        if (t != null) {
+                            callback.onResult(null, t);
+                        } else {
+                            callback.onResult(r, null);
+                        }
+                    });
         } catch (Throwable t) {
             bsonOutput.close();
-            compressedBsonOutput.close();
             callback.onResult(null, t);
         }
     }

-    private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> decoder, final OperationContext operationContext,
-            final SingleResultCallback<T> callback, final ByteBufferBsonOutput bsonOutput,
-            final CommandEventSender commandEventSender, final boolean responseExpected) {
+    private <T> void sendCommandMessageAsync(final List<ByteBuf> messageByteBuffers, final int messageId, final Decoder<T> decoder,
+            final OperationContext operationContext, final CommandEventSender commandEventSender,
+            final boolean responseExpected, final SingleResultCallback<T> callback) {
         boolean[] shouldReturn = {false};
         Timeout.onExistsAndExpired(operationContext.getTimeoutContext().timeoutIncludingRoundTrip(), () -> {
-            bsonOutput.close();
             MongoOperationTimeoutException operationTimeoutException = TimeoutContext.createMongoRoundTripTimeoutException();
             commandEventSender.sendFailedEvent(operationTimeoutException);
             callback.onResult(null, operationTimeoutException);
@@ -661,10 +657,7 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
             return;
         }

-        List<ByteBuf> byteBuffers = bsonOutput.getByteBuffers();
-        sendMessageAsync(byteBuffers, messageId, operationContext, (result, t) -> {
-            ResourceUtil.release(byteBuffers);
-            bsonOutput.close();
+        sendMessageAsync(messageByteBuffers, messageId, operationContext, (result, t) -> {
             if (t != null) {
                 commandEventSender.sendFailedEvent(t);
                 callback.onResult(null, t);
@@ -682,18 +675,16 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
                 T commandResult;
                 try {
                     updateSessionContext(operationContext.getSessionContext(), responseBuffers);
-                    boolean commandOk =
-                            isCommandOk(new BsonBinaryReader(new ByteBufferBsonInput(responseBuffers.getBodyByteBuffer())));
-                    responseBuffers.reset();
-                    if (!commandOk) {
+
+                    if (!isCommandOk(responseBuffers)) {
                         MongoException commandFailureException = getCommandFailureException(
                                 responseBuffers.getResponseDocument(messageId, new BsonDocumentCodec()),
                                 description.getServerAddress(), operationContext.getTimeoutContext());
                         commandEventSender.sendFailedEvent(commandFailureException);
                         throw commandFailureException;
                     }
-                    commandEventSender.sendSucceededEvent(responseBuffers);
+                    commandEventSender.sendSucceededEvent(responseBuffers);
                     commandResult = getCommandResult(decoder, responseBuffers, messageId, operationContext.getTimeoutContext());
                 } catch (Throwable localThrowable) {
                     callback.onResult(null, localThrowable);
@@ -744,13 +735,12 @@ public void sendMessageAsync(
             final int lastRequestId,
             final OperationContext operationContext,
             final SingleResultCallback<Void> callback) {
-        beginAsync().thenRun((c) -> {
+
+        beginAsync().thenRunTryCatchAsyncBlocks(c -> {
             notNull("stream is open", stream);
             if (isClosed()) {
                 throw new MongoSocketClosedException("Cannot write to a closed stream", getServerAddress());
             }
-            c.complete(c);
-        }).thenRunTryCatchAsyncBlocks(c -> {
             stream.writeAsync(byteBuffers, operationContext, c.asHandler());
         }, Exception.class, (e, c) -> {
             try {
@@ -887,6 +877,7 @@ private MongoSocketReadTimeoutException createReadTimeoutException(final Socket
     }

     private ResponseBuffers receiveResponseBuffers(final OperationContext operationContext) {
+        ByteBuf messageBuffer = null;
         try {
             ByteBuf messageHeaderBuffer = stream.read(MESSAGE_HEADER_LENGTH, operationContext);
             MessageHeader messageHeader;
@@ -896,32 +887,41 @@ private ResponseBuffers receiveResponseBuffers(final OperationContext operationC
             messageHeaderBuffer.release();
         }

-        ByteBuf messageBuffer = stream.read(messageHeader.getMessageLength() - MESSAGE_HEADER_LENGTH, operationContext);
-        boolean releaseMessageBuffer = true;
-        try {
-            if (messageHeader.getOpCode() == OP_COMPRESSED.getValue()) {
-                CompressedHeader compressedHeader = new CompressedHeader(messageBuffer, messageHeader);
+            messageBuffer = stream.read(messageHeader.getMessageLength() - MESSAGE_HEADER_LENGTH, operationContext);
+            ResponseBuffers responseBuffers = getResponseBuffers(messageHeader, messageBuffer);
+            messageBuffer = null; // Transfer ownership to ResponseBuffers
+            return responseBuffers;
+        } catch (Throwable t) {
+            if (messageBuffer != null) {
+                messageBuffer.release();
+            }
+            close();
+            throw translateReadException(t, operationContext);
+        }
+    }

-                Compressor compressor = getCompressor(compressedHeader);
+    private ResponseBuffers getResponseBuffers(final MessageHeader messageHeader, final ByteBuf messageBuffer) {
+        boolean releaseMessageBuffer = true;
+        try {
+            if (messageHeader.getOpCode() == OP_COMPRESSED.getValue()) {
+                CompressedHeader compressedHeader = new CompressedHeader(messageBuffer, messageHeader);

-                ByteBuf buffer = getBuffer(compressedHeader.getUncompressedSize());
-                compressor.uncompress(messageBuffer, buffer);
+                Compressor compressor = getCompressor(compressedHeader);

-                buffer.flip();
-                return new ResponseBuffers(new ReplyHeader(buffer, compressedHeader), buffer);
-            } else {
-                ResponseBuffers responseBuffers = new ResponseBuffers(new ReplyHeader(messageBuffer, messageHeader), messageBuffer);
-                releaseMessageBuffer = false;
-                return responseBuffers;
-            }
-        } finally {
-            if (releaseMessageBuffer) {
-                messageBuffer.release();
-            }
+                ByteBuf buffer = getBuffer(compressedHeader.getUncompressedSize());
+                compressor.uncompress(messageBuffer, buffer);
+
+                buffer.flip();
+                return new ResponseBuffers(new ReplyHeader(buffer, compressedHeader), buffer);
+            } else {
+                ResponseBuffers responseBuffers = new ResponseBuffers(new ReplyHeader(messageBuffer, messageHeader), messageBuffer);
+                releaseMessageBuffer = false;
+                return responseBuffers;
+            }
+        } finally {
+            if (releaseMessageBuffer) {
+                messageBuffer.release();
             }
-        } catch (Throwable t) {
-            close();
-            throw translateReadException(t, operationContext);
         }
     }
@@ -957,14 +957,15 @@ public void onResult(@Nullable final ByteBuf result, @Nullable final Throwable t
             try {
                 assertNotNull(result);
                 MessageHeader messageHeader = new MessageHeader(result, description.getMaxMessageSize());
+                result.release(); // Release header buffer
+
                 readAsync(messageHeader.getMessageLength() - MESSAGE_HEADER_LENGTH, operationContext,
                         new MessageCallback(messageHeader));
             } catch (Throwable localThrowable) {
-                callback.onResult(null, localThrowable);
-            } finally {
                 if (result != null) {
                     result.release();
                 }
+                callback.onResult(null, localThrowable);
             }
         }
@@ -981,37 +982,18 @@ public void onResult(@Nullable final ByteBuf result, @Nullable final Throwable t
                 callback.onResult(null, t);
                 return;
             }
-            boolean releaseResult = true;
-            assertNotNull(result);
+            ResponseBuffers responseBuffers = null;
             try {
-                ReplyHeader replyHeader;
-                ByteBuf responseBuffer;
-                if (messageHeader.getOpCode() == OP_COMPRESSED.getValue()) {
-                    try {
-                        CompressedHeader compressedHeader = new CompressedHeader(result, messageHeader);
-                        Compressor compressor = getCompressor(compressedHeader);
-                        ByteBuf buffer = getBuffer(compressedHeader.getUncompressedSize());
-                        compressor.uncompress(result, buffer);
-
-                        buffer.flip();
-                        replyHeader = new ReplyHeader(buffer, compressedHeader);
-                        responseBuffer = buffer;
-                    } finally {
-                        releaseResult = false;
-                        result.release();
-                    }
-                } else {
-                    replyHeader = new ReplyHeader(result, messageHeader);
-                    responseBuffer = result;
-                    releaseResult = false;
-                }
-                callback.onResult(new ResponseBuffers(replyHeader, responseBuffer), null);
+                assertNotNull(result);
+                responseBuffers = getResponseBuffers(messageHeader, result);
+                callback.onResult(responseBuffers, null);
             } catch (Throwable localThrowable) {
-                callback.onResult(null, localThrowable);
-            } finally {
-                if (releaseResult) {
+                if (responseBuffers != null) {
+                    responseBuffers.close();
+                } else if (result != null) {
                     result.release();
                 }
+                callback.onResult(null, localThrowable);
            }
        }
    }
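[Editor's note — illustrative sketch, not part of the patch] The async path above inverts the old cleanup order: the output must stay open until the socket write finishes, because the buffers handed to sendMessageAsync are duplicates of the output's pooled buffers. The lifecycle reduces to:

    List<ByteBuf> messageByteBuffers = bsonOutput.getByteBuffers(); // duplicates, refCount 1 each
    sendMessageAsync(messageByteBuffers, messageId, operationContext, (result, t) -> {
        ResourceUtil.release(messageByteBuffers); // 1. release the duplicates
        bsonOutput.close();                       // 2. only now return pooled buffers to the provider
        // 3. then complete the caller's callback
    });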

diff --git a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java
index 72235b46760..965ac9e951a 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyByteBuf.java
@@ -20,12 +20,13 @@
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
+import java.util.concurrent.atomic.AtomicInteger;

 /**
  * <p>This class is not part of the public API and may be removed or changed at any time</p>
  */
 public final class NettyByteBuf implements ByteBuf {
-
+    private final AtomicInteger referenceCount = new AtomicInteger(1);
     private io.netty.buffer.ByteBuf proxied;
     private boolean isWriting = true;
@@ -271,17 +272,26 @@ public ByteBuffer asNIO() {

     @Override
     public int getReferenceCount() {
-        return proxied.refCnt();
+        return referenceCount.get();
     }

     @Override
     public ByteBuf retain() {
+        if (referenceCount.incrementAndGet() == 1) {
+            referenceCount.decrementAndGet();
+            throw new IllegalStateException("Attempted to increment the reference count when it is already 0");
+        }
         proxied.retain();
         return this;
     }

     @Override
     public void release() {
+        int newRefCount = referenceCount.decrementAndGet();
+        if (newRefCount < 0) {
+            referenceCount.incrementAndGet();
+            throw new IllegalStateException("Attempted to decrement the reference count below 0");
+        }
         proxied.release();
     }
 }
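[Editor's note — illustrative sketch, not part of the patch] NettyByteBuf now keeps its own counter instead of mirroring Netty's refCnt(), so misuse fails fast on the wrapper even while other components still hold references to the proxied Netty buffer:

    NettyByteBuf wrapped = new NettyByteBuf(nettyBuf); // wrapper count = 1
    wrapped.retain();                                  // wrapper count = 2, nettyBuf.refCnt() bumped too
    wrapped.release();
    wrapped.release();                                 // wrapper count = 0
    // a further wrapped.release() now throws IllegalStateException before touching `proxied`,
    // instead of silently over-releasing a pooled buffer that something else still references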

diff --git a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyStream.java b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyStream.java
index 76e10653454..ccd2cda298b 100644
--- a/driver-core/src/main/com/mongodb/internal/connection/netty/NettyStream.java
+++ b/driver-core/src/main/com/mongodb/internal/connection/netty/NettyStream.java
@@ -29,6 +29,8 @@
 import com.mongodb.connection.SslSettings;
 import com.mongodb.internal.connection.OperationContext;
 import com.mongodb.internal.connection.Stream;
+import com.mongodb.internal.diagnostics.logging.Logger;
+import com.mongodb.internal.diagnostics.logging.Loggers;
 import com.mongodb.lang.Nullable;
 import com.mongodb.spi.dns.InetAddressResolver;
 import io.netty.bootstrap.Bootstrap;
@@ -86,6 +88,24 @@
 * {@linkplain #readAsync(int, OperationContext, AsyncCompletionHandler) asynchronous})
 * are not supported by {@link NettyStream}.
 * However, this class does not have a fail-fast mechanism checking for such situations.
+ *
+ * <p><b>ByteBuf Ownership and Reference Counting</b></p>
+ *
+ * <p>This class manages Netty {@link io.netty.buffer.ByteBuf} instances which use reference counting for memory management.
+ * The following ownership rules apply:</p>
+ * <ul>
+ * <li><b>Inbound buffers from Netty:</b> When Netty delivers a buffer via {@link InboundBufferHandler#channelRead0},
+ * the buffer is initially owned by Netty. To keep the buffer, we actively call {@link io.netty.buffer.ByteBuf#retain()}
+ * to increment its reference count before adding it to {@link #pendingInboundBuffers}.</li>
+ * <li><b>{@link #pendingInboundBuffers}:</b> All buffers in this queue have been retained and are owned by this class.
+ * They must be released when consumed or when the stream is closed.</li>
+ * <li><b>Read completion:</b> When a read operation completes successfully, ownership of the returned {@link ByteBuf}
+ * transfers to the caller (via {@link AsyncCompletionHandler#completed}). The caller is responsible for releasing it.</li>
+ * <li><b>Handler exceptions:</b> If {@link AsyncCompletionHandler#completed} throws an exception, ownership was not
+ * transferred, and this class releases the buffer.</li>
+ * <li><b>Stream closure:</b> When {@link #close()} is called, all buffers in {@link #pendingInboundBuffers} are released.
+ * Callers must ensure any previously returned buffers are released before calling close.</li>
+ * </ul>
+ *
 * <p>
 * <sup>1</sup>We cannot simply say that read methods are not allowed to be run concurrently because strictly speaking they are allowed,
 * as explained below.
@@ -113,6 +133,7 @@
 * is invoked after the first operation has completed reading despite the method has not returned yet.
 */
final class NettyStream implements Stream {
+    private static final Logger LOGGER = Loggers.getLogger("connection");
    private static final byte NO_SCHEDULE_TIME = 0;
    private final ServerAddress address;
    private final InetAddressResolver inetAddressResolver;
@@ -127,6 +148,19 @@ final class NettyStream implements Stream {
    private boolean isClosed;
    private volatile Channel channel;

+    /**
+     * Queue of inbound buffers received from Netty that have not yet been consumed by read operations.
+     *
+     * <p><b>Ownership:</b> All buffers in this queue have been {@link io.netty.buffer.ByteBuf#retain() retained}
+     * and are owned by this class. When a buffer is removed from this queue, the remover takes ownership and
+     * is responsible for either:</p>
+     * <ul>
+     * <li>Transferring ownership to a caller (e.g., via {@link AsyncCompletionHandler#completed}), or</li>
+     * <li>Releasing the buffer via {@link io.netty.buffer.ByteBuf#release()}</li>
+     * </ul>
+     *
+     * <p>When the stream is {@link #close() closed}, all remaining buffers are released.</p>
+     */
     private final LinkedList<io.netty.buffer.ByteBuf> pendingInboundBuffers = new LinkedList<>();
     private final Lock lock = new ReentrantLock();
     // access to the fields `pendingReader`, `pendingException` is guarded by `lock`
@@ -227,6 +261,7 @@ public void initChannel(final SocketChannel ch) {

     @Override
     public void write(final List<ByteBuf> buffers, final OperationContext operationContext) throws IOException {
+        validateConnectionState();
         FutureAsyncCompletionHandler<Void> future = new FutureAsyncCompletionHandler<>();
         writeAsync(buffers, operationContext, future);
         future.get();
@@ -242,29 +277,49 @@ public ByteBuf read(final int numBytes, final OperationContext operationContext)

     @Override
     public void writeAsync(final List<ByteBuf> buffers, final OperationContext operationContext,
            final AsyncCompletionHandler<Void> handler) {
+        // Early validation before allocating resources
+        if (isClosed) {
+            handler.failed(new MongoSocketException("Stream is closed", address));
+            return;
+        }
+        Channel localChannel = channel;
+        if (localChannel == null || !localChannel.isActive()) {
+            handler.failed(new MongoSocketException("Channel is not active", address));
+            return;
+        }
+
         CompositeByteBuf composite = PooledByteBufAllocator.DEFAULT.compositeBuffer();
-        for (ByteBuf cur : buffers) {
-            // The Netty framework releases `CompositeByteBuf` after writing
-            // (see https://netty.io/wiki/reference-counted-objects.html#outbound-messages),
-            // which results in the buffer we pass to `CompositeByteBuf.addComponent` being released.
-            // However, `CompositeByteBuf.addComponent` does not retain this buffer,
-            // which means we must retain it to conform to the `Stream.writeAsync` contract.
-            composite.addComponent(true, ((NettyByteBuf) cur).asByteBuf().retain());
-        }
-
-        long writeTimeoutMS = operationContext.getTimeoutContext().getWriteTimeoutMS();
-        final Optional<WriteTimeoutHandler> writeTimeoutHandler = addWriteTimeoutHandler(writeTimeoutMS);
-        channel.writeAndFlush(composite).addListener((ChannelFutureListener) future -> {
-            writeTimeoutHandler.map(w -> channel.pipeline().remove(w));
-            if (!future.isSuccess()) {
-                handler.failed(future.cause());
-            } else {
-                handler.completed(null);
+
+        try {
+            for (ByteBuf cur : buffers) {
+                // The Netty framework releases `CompositeByteBuf` after writing
+                // (see https://netty.io/wiki/reference-counted-objects.html#outbound-messages),
+                // which results in the buffer we pass to `CompositeByteBuf.addComponent` being released.
+                // However, `CompositeByteBuf.addComponent` does not retain this buffer,
+                // which means we must retain it to conform to the `Stream.writeAsync` contract.
+                composite.addComponent(true, ((NettyByteBuf) cur).asByteBuf().retain());
             }
-        });
+
+            long writeTimeoutMS = operationContext.getTimeoutContext().getWriteTimeoutMS();
+            final Optional<WriteTimeoutHandler> writeTimeoutHandler = addWriteTimeoutHandler(localChannel, writeTimeoutMS);
+
+            localChannel.writeAndFlush(composite).addListener((ChannelFutureListener) future -> {
+                writeTimeoutHandler.map(w -> localChannel.pipeline().remove(w));
+
+                if (!future.isSuccess()) {
+                    handler.failed(future.cause());
+                } else {
+                    handler.completed(null);
+                }
+            });
+        } catch (Throwable t) {
+            // If we fail before submitting the write, release the composite
+            composite.release();
+            handler.failed(t);
+        }
     }

-    private Optional<WriteTimeoutHandler> addWriteTimeoutHandler(final long writeTimeoutMS) {
+    private Optional<WriteTimeoutHandler> addWriteTimeoutHandler(final Channel channel, final long writeTimeoutMS) {
         if (writeTimeoutMS != NO_SCHEDULE_TIME) {
             WriteTimeoutHandler writeTimeoutHandler = new WriteTimeoutHandler(writeTimeoutMS, MILLISECONDS);
             channel.pipeline().addBefore("ChannelInboundHandlerAdapter", "WriteTimeoutHandler", writeTimeoutHandler);
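[Editor's note — illustrative sketch, not part of the patch] The retain() inside the loop above is the outbound mirror of the inbound rule: whenever a buffer crosses an ownership boundary, the side that keeps using it must hold its own reference. Netty releases the composite (and thus each component once) after the write completes, so the reference arithmetic balances out:

    // caller still owns `cur` and releases it later, per the Stream.writeAsync contract
    composite.addComponent(true, ((NettyByteBuf) cur).asByteBuf().retain());
    // after writeAndFlush completes, Netty's release consumes exactly the retain taken here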

@@ -285,52 +340,76 @@ public void readAsync(final int numBytes, final OperationContext operationContex
      * Timeouts may be scheduled only by the public read methods. Taking into account that concurrent pending
      * readers are not allowed, there must not be a situation when threads attempt to schedule a timeout
      * before the previous one is either cancelled or completed.
+     *
+     * <p><b>Buffer Ownership</b></p>
+     *
+     * <p>When this method completes a read operation successfully:</p>
+     * <ul>
+     * <li>Buffers are removed from {@link #pendingInboundBuffers} and assembled into a composite buffer</li>
+     * <li>Ownership of the composite buffer is transferred to the handler via {@link AsyncCompletionHandler#completed}</li>
+     * <li>If the handler throws an exception, the buffer is released by {@link #invokeHandlerWithBuffer}</li>
+     * </ul>
+     * <p>If an exception occurs during buffer assembly, the partially-built composite is released before propagating the exception.</p>
      */
     private void readAsync(final int numBytes, final AsyncCompletionHandler<ByteBuf> handler, final long readTimeoutMillis) {
         ByteBuf buffer = null;
-        Throwable exceptionResult = null;
+        Throwable exceptionResult;
         lock.lock();
         try {
-            exceptionResult = pendingException;
-            if (exceptionResult == null) {
-                if (!hasBytesAvailable(numBytes)) {
+
+            if (pendingException == null) {
+                if (isClosed) {
+                    pendingException = new MongoSocketException("Stream was closed", address);
+                    // Release any pending buffers that were retained before the stream was closed
+                    releaseAllPendingInboundBuffers();
+                } else if (channel == null || !channel.isActive()) {
+                    pendingException = new MongoSocketException("Channel is not active", address);
+                    // Release any pending buffers that were retained before the channel became inactive
+                    releaseAllPendingInboundBuffers();
+                } else if (!hasBytesAvailable(numBytes)) {
                     if (pendingReader == null) { //called by a public read method
                         pendingReader = new PendingReader(numBytes, handler, scheduleReadTimeout(readTimeoutTask, readTimeoutMillis));
                     }
                 } else {
                     CompositeByteBuf composite = allocator.compositeBuffer(pendingInboundBuffers.size());
-                    int bytesNeeded = numBytes;
-                    for (Iterator<io.netty.buffer.ByteBuf> iter = pendingInboundBuffers.iterator(); iter.hasNext();) {
-                        io.netty.buffer.ByteBuf next = iter.next();
-                        int bytesNeededFromCurrentBuffer = Math.min(next.readableBytes(), bytesNeeded);
-                        if (bytesNeededFromCurrentBuffer == next.readableBytes()) {
-                            composite.addComponent(next);
-                            iter.remove();
-                        } else {
-                            composite.addComponent(next.readRetainedSlice(bytesNeededFromCurrentBuffer));
-                        }
-                        composite.writerIndex(composite.writerIndex() + bytesNeededFromCurrentBuffer);
-                        bytesNeeded -= bytesNeededFromCurrentBuffer;
-                        if (bytesNeeded == 0) {
-                            break;
+                    try {
+                        int bytesNeeded = numBytes;
+                        for (Iterator<io.netty.buffer.ByteBuf> iter = pendingInboundBuffers.iterator(); iter.hasNext();) {
+                            io.netty.buffer.ByteBuf next = iter.next();
+                            int bytesNeededFromCurrentBuffer = Math.min(next.readableBytes(), bytesNeeded);
+                            if (bytesNeededFromCurrentBuffer == next.readableBytes()) {
+                                iter.remove(); // Remove BEFORE adding to composite
+                                composite.addComponent(true, next);
+                            } else {
+                                composite.addComponent(true, next.readRetainedSlice(bytesNeededFromCurrentBuffer));
+                            }
+
+                            bytesNeeded -= bytesNeededFromCurrentBuffer;
+                            if (bytesNeeded == 0) {
+                                break;
+                            }
                         }
+                        buffer = new NettyByteBuf(composite).flip();
+                    } catch (Throwable t) {
+                        composite.release();
+                        pendingException = t;
                     }
-                    buffer = new NettyByteBuf(composite).flip();
                 }
             }
-            if (!(exceptionResult == null && buffer == null) //the read operation has completed
-                    && pendingReader != null) { //we need to clear the pending reader
+
+            exceptionResult = pendingException;
+            if (!(exceptionResult == null && buffer == null) //the read operation has completed
+                    && pendingReader != null) { //we need to clear the pending reader
                 cancel(pendingReader.timeout);
                 this.pendingReader = null;
             }
         } finally {
             lock.unlock();
         }
+
         if (exceptionResult != null) {
             handler.failed(exceptionResult);
-        }
-        if (buffer != null) {
-            handler.completed(buffer);
+        } else if (buffer != null) {
+            invokeHandlerWithBuffer(buffer, handler);
         }
     }

@@ -350,6 +429,10 @@ private void handleReadResponse(@Nullable final io.netty.buffer.ByteBuf buffer,
         if (buffer != null) {
             if (isClosed) {
                 pendingException = new MongoSocketException("Received data after the stream was closed.", address);
+                // Do not retain the buffer since we're not storing it - let it be released by the caller
+            } else if (channel == null || !channel.isActive()) {
+                pendingException = new MongoSocketException("Channel is not active during read", address);
+                // Do not retain the buffer - the channel is unusable
             } else {
                 pendingInboundBuffers.add(buffer.retain());
             }
@@ -370,22 +453,94 @@ public ServerAddress getAddress() {
         return address;
     }
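[Editor's note — illustrative sketch, not part of the patch] The ownership hand-off at the end of readAsync means a typical consumer owns exactly one reference and must release it; the decode step is a placeholder:

    stream.readAsync(numBytes, operationContext, new AsyncCompletionHandler<ByteBuf>() {
        @Override
        public void completed(final ByteBuf buffer) {
            try {
                // decode from buffer
            } finally {
                buffer.release(); // the stream retained it on our behalf; we must release it
            }
        }

        @Override
        public void failed(final Throwable t) {
            // no buffer was transferred on failure, so nothing to release
        }
    });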

+    /**
+     * Closes this stream and releases all resources.
+     *
+     * <p><b>Buffer Cleanup</b></p>
+     *
+     * <p>All buffers remaining in {@link #pendingInboundBuffers} are released synchronously while holding
+     * the lock to prevent race conditions. Buffers are forcefully released (all reference counts dropped)
+     * to prevent silent leaks.</p>
+     *
+     * <p><b>Write Buffer Handling</b></p>
+     *
+     * <p>Write buffers use retainedSlice(), which creates independent reference counts for Netty and the caller.
+     * This eliminates the need for explicit tracking - each party manages its own buffer lifecycle independently.</p>
+     *
+     * <p><b>Channel Cleanup</b></p>
+     *
+     * <p>If a channel exists, it is closed asynchronously. The async listener performs defensive cleanup
+     * to ensure any buffers added during the close process are also released. This provides defense-in-depth
+     * against race conditions.</p>
+     *
+     * <p><b>Important:</b> Callers must ensure that any {@link ByteBuf} instances previously returned by
+     * {@link #read} or {@link #readAsync} have been released before calling this method. This class does not
+     * track buffers after ownership has been transferred to callers.</p>
+     */
     @Override
     public void close() {
-        withLock(lock, () -> {
-            isClosed = true;
-            if (channel != null) {
-                channel.close();
+        Channel channelToClose = withLock(lock, () -> {
+            if (!isClosed) {
+                isClosed = true;
+
+                // Clean up all pending inbound buffers synchronously while holding the lock
+                // This prevents race conditions where buffers might be added during close
+                releaseAllPendingInboundBuffers();
+
+                // Save the channel reference for async close, then null it out
+                Channel localChannel = channel;
                 channel = null;
+                return localChannel;
             }
-            for (Iterator<io.netty.buffer.ByteBuf> iterator = pendingInboundBuffers.iterator(); iterator.hasNext();) {
-                io.netty.buffer.ByteBuf nextByteBuf = iterator.next();
-                iterator.remove();
-                // Drops all retains to prevent silent leaks; assumes callers have already released
-                // ByteBuffers returned by that NettyStream before calling close.
-                nextByteBuf.release(nextByteBuf.refCnt());
-            }
+            return null;
         });

+        // Close the channel outside the lock to avoid potential deadlocks
+        if (channelToClose != null) {
+            channelToClose.close().addListener((ChannelFutureListener) future -> {
+                // Defensive cleanup: release any buffers that might have been added during close
+                // This is safe because isClosed=true prevents handleReadResponse from retaining new buffers
+                withLock(lock, () -> {
+                    try {
+                        releaseAllPendingInboundBuffers();
+                    } catch (Throwable t) {
+                        // Log but don't propagate - we're in an async callback
+                        LOGGER.warn("Exception while releasing buffers in channel close listener", t);
+                    }
+                });
+            });
+        }
+    }

<p>Each buffer is forcefully released by dropping all of its reference counts (not just one)
+ * to prevent silent leaks in case of reference-counting errors elsewhere.</p>
+ *
+ * <p>This method is idempotent: it can safely be called multiple times.</p>
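+ *
+ * <p>Here "forcefully released" means dropping every outstanding count in a single call, along the lines of
+ * {@code buffer.release(buffer.refCnt())}, rather than a single {@code buffer.release()}.</p>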

+ */ + private void releaseAllPendingInboundBuffers() { + int releasedCount = 0; + int errorCount = 0; + + for (io.netty.buffer.ByteBuf buffer : pendingInboundBuffers) { + try { + int refCnt = buffer.refCnt(); + if (refCnt > 0) { + buffer.release(refCnt); + releasedCount++; + } + } catch (Throwable t) { + errorCount++; + // Log but continue releasing other buffers - we want to release all buffers even if one fails + LOGGER.warn("Exception while releasing buffer with refCount " + buffer.refCnt(), t); + } + } + pendingInboundBuffers.clear(); + + if (LOGGER.isDebugEnabled() && (releasedCount > 0 || errorCount > 0)) { + LOGGER.debug(String.format("Released %d buffers for %s (%d errors)", releasedCount, address, errorCount)); + } } @Override @@ -413,6 +568,52 @@ public ByteBufAllocator getAllocator() { return allocator; } + /** + * Validates that the stream is open and the channel is active. + * + * @throws MongoSocketException if the stream is closed or the channel is not active + */ + private void validateConnectionState() throws MongoSocketException { + if (isClosed) { + throw new MongoSocketException("Stream is closed", address); + } + Channel localChannel = channel; + if (localChannel == null || !localChannel.isActive()) { + throw new MongoSocketException("Channel is not active", address); + } + } + + /** + * Invokes the handler with the buffer, ensuring the buffer is released if the handler throws an exception. + * + *

<h3>Ownership Transfer Protocol</h3>
+ * <p>This method implements a safe ownership transfer:</p>
+ * <ol>
+ *   <li>The buffer is passed to {@link AsyncCompletionHandler#completed}.</li>
+ *   <li>If the handler returns normally, ownership has been successfully transferred to the handler/caller.</li>
+ *   <li>If the handler throws an exception, ownership was NOT transferred, so this method releases the buffer
+ *       before re-throwing the exception.</li>
+ * </ol>
+ *
+ * <p>This ensures that buffers are never leaked, regardless of whether the handler succeeds or fails.</p>
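+ *
+ * <p>The protocol reduces to the following shape (a simplified sketch of this method's body, which
+ * additionally checks the reference count before releasing):</p>
+ * <pre>{@code
+ * try {
+ *     handler.completed(buffer); // ownership transfers on normal return
+ * } catch (Throwable t) {
+ *     buffer.release();          // handler did not take ownership, so release here
+ *     throw t;
+ * }
+ * }</pre>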

+ * + * @param buffer The buffer to pass to the handler. Must not be null. Ownership is transferred to the handler + * on successful completion. + * @param handler The handler to invoke with the buffer. + * @throws RuntimeException if the handler throws an exception (after releasing the buffer) + */ + private void invokeHandlerWithBuffer(final ByteBuf buffer, final AsyncCompletionHandler handler) { + try { + handler.completed(buffer); + } catch (Throwable t) { + // Handler threw an exception, so it didn't take ownership - release the buffer + if (buffer.getReferenceCount() > 0) { + buffer.release(); + } + throw t; + } + } + private void addSslHandler(final SocketChannel channel) { SSLEngine engine; if (sslContext == null) { @@ -525,13 +726,22 @@ private class OpenChannelFutureListener implements ChannelFutureListener { public void operationComplete(final ChannelFuture future) { withLock(lock, () -> { if (future.isSuccess()) { + Channel newChannel = channelFuture.channel(); if (isClosed) { - channelFuture.channel().close(); + // Monitor closed during connection - clean up immediately + if (newChannel != null) { + newChannel.close(); + } + handler.completed(null); + } else if (newChannel == null || !newChannel.isActive()) { + // Channel invalid - treat as failure + handler.failed(new MongoSocketException("Channel is not active after connection", address)); } else { - channel = channelFuture.channel(); - channel.closeFuture().addListener((ChannelFutureListener) future1 -> handleReadResponse(null, new IOException("The connection to the server was closed"))); + channel = newChannel; + channel.closeFuture().addListener((ChannelFutureListener) future1 -> + handleReadResponse(null, new IOException("The connection to the server was closed"))); + handler.completed(null); } - handler.completed(null); } else { if (isClosed) { handler.completed(null); diff --git a/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java b/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java index 2dadd11efec..b26cb396e7b 100644 --- a/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java +++ b/driver-core/src/main/com/mongodb/internal/observability/micrometer/TracingManager.java @@ -37,6 +37,7 @@ import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.CLIENT_CONNECTION_ID; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.COLLECTION; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.COMMAND_NAME; +import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.CURSOR_ID; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.NAMESPACE; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.NETWORK_TRANSPORT; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.QUERY_SUMMARY; @@ -46,7 +47,6 @@ import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.SESSION_ID; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.SYSTEM; import static com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.TRANSACTION_NUMBER; -import static 
com.mongodb.internal.observability.micrometer.MongodbObservation.LowCardinalityKeyNames.CURSOR_ID; import static java.lang.System.getenv; /** @@ -178,7 +178,7 @@ public boolean isCommandPayloadEnabled() { * * @param message the command message to trace * @param operationContext the operation context containing tracing and session information - * @param commandDocumentSupplier a supplier that provides the command document when needed + * @param commandDocument the command document, note this is an internally managed resource * @param isSensitiveCommand a predicate that determines if a command is security-sensitive based on its name * @param serverAddressSupplier a supplier that provides the server address when needed * @param connectionIdSupplier a supplier that provides the connection ID when needed @@ -187,26 +187,26 @@ public boolean isCommandPayloadEnabled() { @Nullable public Span createTracingSpan(final CommandMessage message, final OperationContext operationContext, - final Supplier commandDocumentSupplier, + final BsonDocument commandDocument, final Predicate isSensitiveCommand, final Supplier serverAddressSupplier, final Supplier connectionIdSupplier ) { - if (!isEnabled()) { + if (!isEnabled()) { return null; } - BsonDocument command = commandDocumentSupplier.get(); - String commandName = command.getFirstKey(); + + String commandName = commandDocument.getFirstKey(); if (isSensitiveCommand.test(commandName)) { return null; } Span operationSpan = operationContext.getTracingSpan(); - Span span = addSpan(commandName, operationSpan != null ? operationSpan.context() : null); + Span span = addSpan(commandName, operationSpan != null ? operationSpan.context() : null); - if (command.containsKey("getMore")) { - long cursorId = command.getInt64("getMore").longValue(); + if (commandDocument.containsKey("getMore")) { + long cursorId = commandDocument.getInt64("getMore").longValue(); span.tagLowCardinality(CURSOR_ID.withValue(String.valueOf(cursorId))); if (operationSpan != null) { operationSpan.tagLowCardinality(CURSOR_ID.withValue(String.valueOf(cursorId))); diff --git a/driver-core/src/test/functional/com/mongodb/connection/netty/NettyStreamSpecification.groovy b/driver-core/src/test/functional/com/mongodb/connection/netty/NettyStreamSpecification.groovy deleted file mode 100644 index e582e0fc398..00000000000 --- a/driver-core/src/test/functional/com/mongodb/connection/netty/NettyStreamSpecification.groovy +++ /dev/null @@ -1,125 +0,0 @@ -package com.mongodb.connection.netty - -import com.mongodb.MongoSocketException -import com.mongodb.MongoSocketOpenException -import com.mongodb.ServerAddress -import com.mongodb.connection.AsyncCompletionHandler -import com.mongodb.connection.SocketSettings -import com.mongodb.connection.SslSettings -import com.mongodb.internal.connection.netty.NettyStreamFactory -import com.mongodb.spi.dns.InetAddressResolver -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.socket.nio.NioSocketChannel -import spock.lang.IgnoreIf -import spock.lang.Specification -import com.mongodb.spock.Slow - -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit - -import static com.mongodb.ClusterFixture.OPERATION_CONTEXT -import static com.mongodb.ClusterFixture.getSslSettings - -class NettyStreamSpecification extends Specification { - - @Slow - @IgnoreIf({ getSslSettings().isEnabled() }) - def 'should successfully connect with working ip address group'() { - given: - SocketSettings 
socketSettings = SocketSettings.builder().connectTimeout(1000, TimeUnit.MILLISECONDS).build() - SslSettings sslSettings = SslSettings.builder().build() - def inetAddressResolver = new InetAddressResolver() { - @Override - List lookupByName(String host) { - [InetAddress.getByName('192.168.255.255'), - InetAddress.getByName('1.2.3.4'), - InetAddress.getByName('127.0.0.1')] - } - } - def factory = new NettyStreamFactory(inetAddressResolver, socketSettings, sslSettings, new NioEventLoopGroup(), - NioSocketChannel, PooledByteBufAllocator.DEFAULT, null) - - def stream = factory.create(new ServerAddress()) - - when: - stream.open(OPERATION_CONTEXT) - - then: - !stream.isClosed() - } - - @Slow - @IgnoreIf({ getSslSettings().isEnabled() }) - def 'should throw exception with non-working ip address group'() { - given: - SocketSettings socketSettings = SocketSettings.builder().connectTimeout(1000, TimeUnit.MILLISECONDS).build() - SslSettings sslSettings = SslSettings.builder().build() - def inetAddressResolver = new InetAddressResolver() { - @Override - List lookupByName(String host) { - [InetAddress.getByName('192.168.255.255'), - InetAddress.getByName('1.2.3.4'), - InetAddress.getByName('1.2.3.5')] - } - } - def factory = new NettyStreamFactory(inetAddressResolver, socketSettings, sslSettings, new NioEventLoopGroup(), - NioSocketChannel, PooledByteBufAllocator.DEFAULT, null) - - def stream = factory.create(new ServerAddress()) - - when: - stream.open(OPERATION_CONTEXT) - - then: - thrown(MongoSocketOpenException) - } - - @IgnoreIf({ getSslSettings().isEnabled() }) - def 'should fail AsyncCompletionHandler if name resolution fails'() { - given: - def serverAddress = Stub(ServerAddress) - def exception = new MongoSocketException('Temporary failure in name resolution', serverAddress) - serverAddress.getSocketAddresses() >> { throw exception } - - SocketSettings socketSettings = SocketSettings.builder().connectTimeout(1000, TimeUnit.MILLISECONDS).build() - SslSettings sslSettings = SslSettings.builder().build() - def inetAddressResolver = new InetAddressResolver() { - @Override - List lookupByName(String host) { - throw exception - } - } - def stream = new NettyStreamFactory(inetAddressResolver, socketSettings, sslSettings, new NioEventLoopGroup(), - NioSocketChannel, PooledByteBufAllocator.DEFAULT, null) - .create(new ServerAddress()) - def callback = new CallbackErrorHolder() - - when: - stream.openAsync(OPERATION_CONTEXT, callback) - - then: - callback.getError().is(exception) - } - - class CallbackErrorHolder implements AsyncCompletionHandler { - CountDownLatch latch = new CountDownLatch(1) - Throwable throwable = null - - Throwable getError() { - latch.countDown() - throwable - } - - @Override - void completed(Void r) { - latch.await() - } - - @Override - void failed(Throwable t) { - throwable = t - latch.countDown() - } - } -} diff --git a/driver-core/src/test/functional/com/mongodb/connection/netty/NettyStreamTest.java b/driver-core/src/test/functional/com/mongodb/connection/netty/NettyStreamTest.java new file mode 100644 index 00000000000..a72a8228066 --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/connection/netty/NettyStreamTest.java @@ -0,0 +1,1110 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.connection.netty; + +import com.mongodb.MongoSocketException; +import com.mongodb.MongoSocketOpenException; +import com.mongodb.ServerAddress; +import com.mongodb.connection.AsyncCompletionHandler; +import com.mongodb.connection.SocketSettings; +import com.mongodb.connection.SslSettings; +import com.mongodb.internal.ResourceUtil; +import com.mongodb.internal.connection.Stream; +import com.mongodb.internal.connection.netty.NettyByteBuf; +import com.mongodb.internal.connection.netty.NettyStreamFactory; +import com.mongodb.spi.dns.InetAddressResolver; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.ReferenceCounted; +import org.bson.ByteBuf; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledIf; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static com.mongodb.ClusterFixture.OPERATION_CONTEXT; +import static com.mongodb.ClusterFixture.getSslSettings; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + + +@SuppressWarnings("deprecation") +@DisplayName("NettyStream - Connection, Validation & Leak Prevention Tests") +class NettyStreamTest { + + private NioEventLoopGroup eventLoopGroup; + private NettyStreamFactory factory; + private Stream stream; + private TrackingNettyByteBufAllocator trackingAllocator; + + @BeforeEach + void setUp() { + eventLoopGroup = new NioEventLoopGroup(); + trackingAllocator = new TrackingNettyByteBufAllocator(); + } + + @AfterEach + void tearDown() { + if (stream != null && !stream.isClosed()) { + stream.close(); + } + if (eventLoopGroup != null) { + eventLoopGroup.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + } + System.gc(); + } + + private static boolean isSslEnabled() { + return getSslSettings().isEnabled(); + } + + // ========== Original Tests Converted from Spock 
========== + + @Test + @Tag("Slow") + @DisabledIf("isSslEnabled") + @DisplayName("Should successfully connect when at least one IP address in the group is reachable") + void shouldSuccessfullyConnectWithWorkingIpAddressGroup() throws IOException { + // given + SocketSettings socketSettings = SocketSettings.builder() + .connectTimeout(1000, TimeUnit.MILLISECONDS) + .build(); + SslSettings sslSettings = SslSettings.builder().build(); + + InetAddressResolver inetAddressResolver = host -> { + try { + return asList( + InetAddress.getByName("192.168.255.255"), + InetAddress.getByName("1.2.3.4"), + InetAddress.getByName("127.0.0.1") + ); + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + }; + + factory = new NettyStreamFactory(inetAddressResolver, socketSettings, sslSettings, + eventLoopGroup, NioSocketChannel.class, trackingAllocator, null); + stream = factory.create(new ServerAddress()); + + // when + stream.open(OPERATION_CONTEXT); + + // then + assertFalse(stream.isClosed()); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @Tag("Slow") + @DisabledIf("isSslEnabled") + @DisplayName("Should throw MongoSocketOpenException when all IP addresses in the group are unreachable") + void shouldThrowExceptionWithNonWorkingIpAddressGroup() { + // given + SocketSettings socketSettings = SocketSettings.builder() + .connectTimeout(1000, TimeUnit.MILLISECONDS) + .build(); + SslSettings sslSettings = SslSettings.builder().build(); + + InetAddressResolver inetAddressResolver = host -> { + try { + return asList( + InetAddress.getByName("192.168.255.255"), + InetAddress.getByName("1.2.3.4"), + InetAddress.getByName("1.2.3.5") + ); + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + }; + + factory = new NettyStreamFactory(inetAddressResolver, socketSettings, sslSettings, + eventLoopGroup, NioSocketChannel.class, trackingAllocator, null); + stream = factory.create(new ServerAddress()); + + // when/then + assertThrows(MongoSocketOpenException.class, () -> stream.open(OPERATION_CONTEXT)); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should fail AsyncCompletionHandler when DNS name resolution fails") + void shouldFailAsyncCompletionHandlerIfNameResolutionFails() throws InterruptedException { + // given + ServerAddress serverAddress = new ServerAddress("nonexistent.invalid.hostname.test"); + MongoSocketException exception = new MongoSocketException("Temporary failure in name resolution", serverAddress); + + InetAddressResolver inetAddressResolver = host -> { + throw exception; + }; + + SocketSettings socketSettings = SocketSettings.builder() + .connectTimeout(1000, TimeUnit.MILLISECONDS) + .build(); + SslSettings sslSettings = SslSettings.builder().build(); + + factory = new NettyStreamFactory(inetAddressResolver, socketSettings, sslSettings, + eventLoopGroup, NioSocketChannel.class, trackingAllocator, null); + stream = factory.create(serverAddress); + + CallbackErrorHolder callback = new CallbackErrorHolder<>(); + + // when + stream.openAsync(OPERATION_CONTEXT, callback); + + // then + Throwable error = callback.getError(); + assertSame(exception, error); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); } + + // ========== New Tests for Defensive Improvements ========== + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should throw MongoSocketException when attempting to write to a closed stream") + 
void shouldThrowExceptionWhenWritingToClosedStream() throws IOException { + // given - open and then close stream + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + stream.close(); + + // when/then - write should fail with MongoSocketException + List buffers = createTestBuffers("test"); + + MongoSocketException exception = assertThrows(MongoSocketException.class, + () -> stream.write(buffers, OPERATION_CONTEXT)); + + // stream.write doesn't release the passed in buffers + ResourceUtil.release(buffers); + + assertTrue(exception.getMessage().contains("Stream is closed")); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should fail async write operation with clear error when stream is closed") + void shouldFailAsyncWriteWhenStreamIsClosed() throws IOException, InterruptedException { + // given - open and then close stream + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + stream.close(); + + // when - attempt async write + List buffers = createTestBuffers("test"); + + CallbackErrorHolder callback = new CallbackErrorHolder<>(); + stream.writeAsync(buffers, OPERATION_CONTEXT, callback); + + // then + Throwable error = callback.getError(); + assertNotNull(error); + assertInstanceOf(MongoSocketException.class, error); + assertTrue(error.getMessage().contains("Stream is closed")); + + // stream.writeAsync doesn't release the passed in buffers + ResourceUtil.release(buffers); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should throw exception when attempting to read from a closed stream") + void shouldThrowExceptionWhenReadingFromClosedStream() throws IOException { + // given - open and then close stream + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + stream.close(); + + // when/then - read should fail with IOException or MongoSocketException + Exception exception = assertThrows(Exception.class, + () -> stream.read(1024, OPERATION_CONTEXT)); + + assertTrue(exception instanceof IOException || exception instanceof MongoSocketException, + "Expected IOException or MongoSocketException but got: " + exception.getClass().getName()); + assertTrue(exception.getMessage().contains("closed") + || exception.getMessage().contains("Channel is not active") + || exception.getMessage().contains("connection")); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should handle multiple consecutive close() calls gracefully without errors") + void shouldHandleMultipleCloseCallsGracefully() throws IOException { + // given + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + + // when - close multiple times + stream.close(); + stream.close(); + stream.close(); + + // then - should not throw exception and stream should be closed + assertTrue(stream.isClosed()); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should prevent write operations after the underlying channel becomes inactive") + void shouldPreventWriteOperationsAfterChannelBecomesInactive() throws IOException { + // given - stream opened + 
factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + + // Simulate channel becoming inactive by closing + stream.close(); + + // when/then - operations should fail + List buffers = createTestBuffers("test"); + + MongoSocketException exception = assertThrows(MongoSocketException.class, + () -> stream.write(buffers, OPERATION_CONTEXT)); + + // stream.write doesn't release the passed in buffers + ResourceUtil.release(buffers); + + assertTrue(exception.getMessage().contains("closed") + || exception.getMessage().contains("not active")); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should validate connection state before executing write operations") + void shouldValidateConnectionStateBeforeWrite() throws IOException { + // given - create stream but don't open it + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + + // when/then - write should fail because stream not opened + List buffers = createTestBuffers("test"); + + assertThrows(MongoSocketException.class, + () -> stream.write(buffers, OPERATION_CONTEXT)); + + // stream.write doesn't release the passed in buffers + ResourceUtil.release(buffers); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should handle async write failure gracefully with proper error callback") + void shouldHandleAsyncWriteFailureGracefully() throws InterruptedException { + // given - stream not opened + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + + // when - attempt async write without opening + List buffers = createTestBuffers("test"); + + CallbackErrorHolder callback = new CallbackErrorHolder<>(); + stream.writeAsync(buffers, OPERATION_CONTEXT, callback); + + // then - should fail with clear error + Throwable error = callback.getError(); + assertNotNull(error); + assertInstanceOf(MongoSocketException.class, error); + + // stream.writeAsync doesn't release the passed in buffers + ResourceUtil.release(buffers); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisplayName("Should return the correct server address from getAddress()") + void shouldReturnCorrectAddress() { + // given + ServerAddress serverAddress = new ServerAddress("localhost", 27017); + factory = createDefaultFactory(); + stream = factory.create(serverAddress); + + // when + ServerAddress address = stream.getAddress(); + + // then + assertEquals(serverAddress, address); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisplayName("Should correctly report stream closed state throughout its lifecycle") + void shouldReportClosedStateCorrectly() throws IOException { + // given + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + + // when/then - initially not closed (even without opening) + assertFalse(stream.isClosed()); + + // open and verify still not closed + stream.open(OPERATION_CONTEXT); + assertFalse(stream.isClosed()); + + // close and verify is closed + stream.close(); + assertTrue(stream.isClosed()); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should handle concurrent close() operations from multiple threads without errors") + void 
shouldHandleConcurrentCloseGracefully() throws IOException, InterruptedException { + // given + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + + // when - close from multiple threads + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch completionLatch = new CountDownLatch(3); + AtomicReference exceptionRef = new AtomicReference<>(); + + for (int i = 0; i < 3; i++) { + new Thread(() -> { + try { + startLatch.await(); + stream.close(); + } catch (Exception e) { + exceptionRef.set(e); + } finally { + completionLatch.countDown(); + } + }).start(); + } + + startLatch.countDown(); + assertTrue(completionLatch.await(5, TimeUnit.SECONDS)); + + // then - no exceptions should occur + assertNull(exceptionRef.get()); + assertTrue(stream.isClosed()); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should prevent ByteBuf leaks when write operations fail on a closed stream") + void shouldPreventBufferLeakWhenWriteFailsOnClosedStream() throws IOException, InterruptedException { + // given - create and open stream + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + + // Close the stream + stream.close(); + + // when - attempt multiple writes (to test buffer cleanup) + List buffers = createTestBuffers("test1", "test2"); + + CallbackErrorHolder callback1 = new CallbackErrorHolder<>(); + CallbackErrorHolder callback2 = new CallbackErrorHolder<>(); + CallbackErrorHolder callback3 = new CallbackErrorHolder<>(); + + stream.writeAsync(buffers, OPERATION_CONTEXT, callback1); + stream.writeAsync(buffers, OPERATION_CONTEXT, callback2); + stream.writeAsync(buffers, OPERATION_CONTEXT, callback3); + + // then - all should fail gracefully + assertNotNull(callback1.getError()); + assertNotNull(callback2.getError()); + assertNotNull(callback3.getError()); + + // stream.writeAsync doesn't release the passed in buffers + ResourceUtil.release(buffers); + + // Verify no buffer leaks - stream should have released buffers on failed async write + trackingAllocator.assertAllBuffersReleased(); + } + + // ========== Additional Comprehensive Test Scenarios ========== + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should not allocate composite buffer when channel is inactive before write") + void shouldNotAllocateBufferWhenChannelInactiveBeforeWrite() throws IOException, InterruptedException { + // given - create stream but immediately close it + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + stream.close(); + + // when - attempt async write with large buffers (would allocate resources) + List largeBuffers = createLargeTestBuffers(1024 * 1024, 2); + + CallbackErrorHolder callback = new CallbackErrorHolder<>(); + stream.writeAsync(largeBuffers, OPERATION_CONTEXT, callback); + + // then - should fail immediately without allocating composite buffer + Throwable error = callback.getError(); + assertNotNull(error); + assertTrue(error.getMessage().contains("closed") || error.getMessage().contains("not active")); + + // stream.writeAsync doesn't release the passed in buffers + ResourceUtil.release(largeBuffers); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should fail fast when attempting operations on never-opened stream") + void 
shouldFailFastOnNeverOpenedStream() { + // given - stream created but never opened + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + + // when/then - all operations should fail immediately + List buffers = createTestBuffers("test"); + + // Write should fail + assertThrows(MongoSocketException.class, + () -> stream.write(buffers, OPERATION_CONTEXT)); + + // stream.write doesn't release the passed in buffers + ResourceUtil.release(buffers); + + // Read should fail + assertThrows(Exception.class, + () -> stream.read(1024, OPERATION_CONTEXT)); + + // Stream should not be marked as closed (never opened) + assertFalse(stream.isClosed()); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should handle async read operation gracefully on closed stream") + void shouldHandleAsyncReadOnClosedStream() throws IOException, InterruptedException { + // given - open and close stream + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + stream.close(); + + // when - attempt async read + CallbackErrorHolder callback = new CallbackErrorHolder<>(); + stream.readAsync(1024, OPERATION_CONTEXT, callback); + + // then - should fail with appropriate error + Throwable error = callback.getError(); + assertNotNull(error); + assertTrue(error instanceof IOException || error instanceof MongoSocketException); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @SuppressWarnings("WriteOnlyObject") + @DisplayName("Should handle concurrent write and close operations without deadlock") + void shouldHandleConcurrentWriteAndCloseWithoutDeadlock() throws IOException, InterruptedException { + // given + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + + CountDownLatch writeLatch = new CountDownLatch(1); + CountDownLatch closeLatch = new CountDownLatch(1); + AtomicReference writeException = new AtomicReference<>(); + AtomicReference closeException = new AtomicReference<>(); + + // when - write and close concurrently + Thread writeThread = new Thread(() -> { + List buffers = createTestBuffers("concurrent test"); + try { + writeLatch.await(); + stream.write(buffers, OPERATION_CONTEXT); + } catch (Exception e) { + writeException.set(e); + } finally { + // stream.write doesn't release the passed in buffers + ResourceUtil.release(buffers); + } + }); + + Thread closeThread = new Thread(() -> { + try { + closeLatch.await(); + stream.close(); + } catch (Exception e) { + closeException.set(e); + } + }); + + writeThread.start(); + closeThread.start(); + + // Start both operations nearly simultaneously + writeLatch.countDown(); + closeLatch.countDown(); + + // Wait for completion + writeThread.join(5000); + closeThread.join(5000); + + // then - no deadlock should occur, operations complete + assertFalse(writeThread.isAlive(), "Write thread should complete"); + assertFalse(closeThread.isAlive(), "Close thread should complete"); + assertTrue(stream.isClosed(), "Stream should be closed"); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should handle empty buffer list in write operations") + void shouldHandleEmptyBufferList() throws IOException { + // given + factory = createDefaultFactory(); + stream = factory.create(new 
ServerAddress()); + stream.open(OPERATION_CONTEXT); + + // when - write empty buffer list + List emptyBuffers = emptyList(); + + // then - should not throw exception (implementation specific behavior) + assertDoesNotThrow(() -> stream.write(emptyBuffers, OPERATION_CONTEXT)); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @Tag("Slow") + @DisabledIf("isSslEnabled") + @DisplayName("Should properly cleanup resources when async open operation fails") + void shouldCleanupResourcesWhenAsyncOpenFails() throws InterruptedException { + // given - factory with unreachable address + SocketSettings socketSettings = SocketSettings.builder() + .connectTimeout(500, TimeUnit.MILLISECONDS) // Longer timeout to allow connection attempt + .build(); + SslSettings sslSettings = SslSettings.builder().build(); + + InetAddressResolver resolver = host -> { + try { + return singletonList(InetAddress.getByName("192.168.255.255")); // Unreachable + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + }; + + factory = new NettyStreamFactory(resolver, socketSettings, sslSettings, + eventLoopGroup, NioSocketChannel.class, trackingAllocator, null); + stream = factory.create(new ServerAddress()); + + // when - attempt async open + CallbackErrorHolder callback = new CallbackErrorHolder<>(); + stream.openAsync(OPERATION_CONTEXT, callback); + + // then - should fail and cleanup resources (may take time due to connection timeout) + Throwable error = callback.getError(); + assertNotNull(error); + assertInstanceOf(MongoSocketException.class, error); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisplayName("Should return consistent ServerAddress across multiple calls") + void shouldReturnConsistentServerAddress() { + // given + ServerAddress serverAddress = new ServerAddress("test.mongodb.com", 27017); + factory = createDefaultFactory(); + stream = factory.create(serverAddress); + + // when - call getAddress multiple times + ServerAddress addr1 = stream.getAddress(); + ServerAddress addr2 = stream.getAddress(); + ServerAddress addr3 = stream.getAddress(); + + // then - all should return same address + assertEquals(serverAddress, addr1); + assertEquals(serverAddress, addr2); + assertEquals(serverAddress, addr3); + assertSame(addr1, addr2); // Should return same instance + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should handle rapid open and close cycles without resource leaks") + void shouldHandleRapidOpenCloseCycles() throws IOException { + // given + factory = createDefaultFactory(); + + // when - perform multiple rapid open/close cycles + for (int i = 0; i < 5; i++) { + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + assertFalse(stream.isClosed(), "Stream should be open after opening"); + + stream.close(); + assertTrue(stream.isClosed(), "Stream should be closed after closing"); + } + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should validate stream state remains consistent after failed operations") + void shouldMaintainConsistentStateAfterFailedOperations() throws IOException { + // given + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + ServerAddress originalAddress = stream.getAddress(); + + // when - perform operation 
that will fail + stream.close(); + + List buffers = createTestBuffers("test"); + + try { + stream.write(buffers, OPERATION_CONTEXT); + } catch (MongoSocketException e) { + // Expected + } finally { + // stream.write doesn't release the passed in buffers + ResourceUtil.release(buffers); + } + + // then - stream state should remain consistent + assertTrue(stream.isClosed()); + assertEquals(originalAddress, stream.getAddress()); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should release retained buffers when stream closes during read assembly") + void shouldReleaseBuffersWhenStreamClosesDuringReadAssembly() throws IOException, InterruptedException { + // This test simulates the race condition where: + // 1. Async read is initiated (creates pendingReader) + // 2. Stream is closed by another thread before data arrives + // 3. If data were to arrive and be retained, it must be cleaned up + // 4. The fix ensures readAsync() releases buffers when detecting closed state + + // given - open stream and initiate async read + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + + // Start an async read that will wait for data (creates pendingReader) + CallbackErrorHolder asyncReadCallback = new CallbackErrorHolder<>(); + stream.readAsync(1024, OPERATION_CONTEXT, asyncReadCallback); + + // Give the async read time to set up + Thread.sleep(10); + + // when - close the stream while async read is pending + stream.close(); + + // then - the async read should fail with appropriate error + Throwable error = asyncReadCallback.getError(); + assertNotNull(error, "Async read should fail when stream is closed"); + assertTrue(error instanceof IOException || error instanceof MongoSocketException, + "Expected IOException or MongoSocketException but got: " + error.getClass().getName()); + assertTrue(stream.isClosed(), "Stream should be closed"); + + // Attempt another read after close - should also fail and clean up any pending buffers + Exception exception = assertThrows(Exception.class, + () -> stream.read(1024, OPERATION_CONTEXT)); + assertTrue(exception instanceof IOException || exception instanceof MongoSocketException); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @DisplayName("Should handle rapid close during async read operation with race condition") + void shouldHandleCloseRaceWithConcurrentReads() throws IOException, InterruptedException { + // This test exercises the specific race condition from the SSL/TLS leak: + // Thread 1: Initiates async read (creates pendingReader) + // Thread 2: Closes the stream + // Thread 1: If data arrives, handleReadResponse would retain buffer + // Thread 1: readAsync detects closed state and must clean up + + // given + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + + // Start an async read that will wait for data + CallbackErrorHolder callback = new CallbackErrorHolder<>(); + + CountDownLatch readStarted = new CountDownLatch(1); + CountDownLatch closeStarted = new CountDownLatch(1); + AtomicReference closeException = new AtomicReference<>(); + + // Thread 1: Start async read + Thread readThread = new Thread(() -> { + try { + readStarted.countDown(); + stream.readAsync(1024, OPERATION_CONTEXT, callback); + closeStarted.await(); // Wait for close to happen + } catch 
(Exception e) { + // Expected - stream may be closed + } + }); + + // Thread 2: Close the stream + Thread closeThread = new Thread(() -> { + try { + readStarted.await(); // Wait for read to start + Thread.sleep(10); // Small delay to increase race window + closeStarted.countDown(); + stream.close(); + } catch (Exception e) { + closeException.set(e); + } + }); + + // when - execute both threads + readThread.start(); + closeThread.start(); + + readThread.join(2000); + closeThread.join(2000); + + // then - threads should complete + assertFalse(readThread.isAlive(), "Read thread should complete"); + assertFalse(closeThread.isAlive(), "Close thread should complete"); + assertNull(closeException.get(), "Close should not throw exception"); + assertTrue(stream.isClosed(), "Stream should be closed"); + + // The async read callback should either: + // - Receive an error (if it noticed the close), OR + // - Still be waiting (if close happened first and prevented read setup) + // Either way, no buffer leaks should occur + + // Try another operation to ensure state is consistent + assertThrows(Exception.class, () -> stream.read(100, OPERATION_CONTEXT), + "Operations after close should fail"); + + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + @Test + @DisabledIf("isSslEnabled") + @SuppressWarnings("ThrowableNotThrown") + @DisplayName("Should handle multiple async operations with different callbacks") + void shouldHandleMultipleAsyncOperationsWithDifferentCallbacks() throws IOException, InterruptedException { + // given + factory = createDefaultFactory(); + stream = factory.create(new ServerAddress()); + stream.open(OPERATION_CONTEXT); + + List buffers = createTestBuffers("async1"); + + // when - submit multiple async writes with different callbacks + CallbackErrorHolder callback1 = new CallbackErrorHolder<>(); + CallbackErrorHolder callback2 = new CallbackErrorHolder<>(); + CallbackErrorHolder callback3 = new CallbackErrorHolder<>(); + + stream.writeAsync(buffers, OPERATION_CONTEXT, callback1); + stream.writeAsync(buffers, OPERATION_CONTEXT, callback2); + stream.writeAsync(buffers, OPERATION_CONTEXT, callback3); + + // then - all callbacks should be invoked (either success or failure) + // Wait for all to complete + callback1.getError(); // May be null (success) or error + callback2.getError(); + callback3.getError(); + + // stream.writeAsync doesn't release the passed in buffers + ResourceUtil.release(buffers); + + // Test passes if we reach here without timeout + // Verify no buffer leaks + trackingAllocator.assertAllBuffersReleased(); + } + + // ========== Helper Methods ========== + + /** + * Creates ByteBufs with the given data, wrapped in NettyByteBuf. + */ + private List createTestBuffers(final String... dataArray) { + List buffers = new ArrayList<>(); + for (String data : dataArray) { + io.netty.buffer.ByteBuf nettyBuffer = trackingAllocator.buffer(); + nettyBuffer.writeBytes(data.getBytes()); + buffers.add(new NettyByteBuf(nettyBuffer)); + } + return buffers; + } + + /** + * Creates multiple large ByteBufs with specified capacity, wrapped in NettyByteBuf. 
+ */ + private List createLargeTestBuffers(final int capacityBytes, final int count) { + List buffers = new ArrayList<>(); + for (int i = 0; i < count; i++) { + io.netty.buffer.ByteBuf nettyBuffer = trackingAllocator.buffer(capacityBytes); + buffers.add(new NettyByteBuf(nettyBuffer)); + } + return buffers; + } + + private NettyStreamFactory createDefaultFactory() { + SocketSettings socketSettings = SocketSettings.builder() + .connectTimeout(10000, TimeUnit.MILLISECONDS) + .build(); + SslSettings sslSettings = SslSettings.builder().build(); + + InetAddressResolver resolver = host -> { + try { + return singletonList(InetAddress.getByName(host)); + } catch (UnknownHostException e) { + throw new RuntimeException(e); + } + }; + + // Use tracking allocator for all tests to get explicit leak verification + return new NettyStreamFactory(resolver, socketSettings, sslSettings, + eventLoopGroup, NioSocketChannel.class, trackingAllocator, null); + } + + /** + * Helper class to capture async completion handler results. + */ + private static class CallbackErrorHolder implements AsyncCompletionHandler { + private final CountDownLatch latch = new CountDownLatch(1); + private final AtomicReference throwableRef = new AtomicReference<>(); + private final AtomicReference resultRef = new AtomicReference<>(); + + Throwable getError() throws InterruptedException { + assertTrue(latch.await(5, TimeUnit.SECONDS), "Callback not completed within timeout"); + return throwableRef.get(); + } + + T getResult() throws InterruptedException { + assertTrue(latch.await(5, TimeUnit.SECONDS), "Callback not completed within timeout"); + return resultRef.get(); + } + + @Override + public void completed(final T result) { + resultRef.set(result); + latch.countDown(); + } + + @Override + public void failed(@NotNull final Throwable t) { + throwableRef.set(t); + latch.countDown(); + } + } + + /** + * Tracking allocator that records all buffer allocations and can verify they are all released. + * This allows us to explicitly prove no buffer leaks occur, rather than relying solely on + * Netty's leak detector. 
+ */ + private static class TrackingNettyByteBufAllocator extends PooledByteBufAllocator { + private final List allocatedBuffers = new ArrayList<>(); + private final AtomicInteger allocationCount = new AtomicInteger(0); + + TrackingNettyByteBufAllocator() { + super(false); // Use heap buffers for testing + } + + @Override + public io.netty.buffer.CompositeByteBuf compositeBuffer(final int maxNumComponents) { + io.netty.buffer.CompositeByteBuf buffer = super.compositeBuffer(maxNumComponents); + trackBuffer(buffer); + return buffer; + } + + @Override + public io.netty.buffer.CompositeByteBuf compositeDirectBuffer(final int maxNumComponents) { + io.netty.buffer.CompositeByteBuf buffer = super.compositeDirectBuffer(maxNumComponents); + trackBuffer(buffer); + return buffer; + } + + @Override + public io.netty.buffer.ByteBuf buffer() { + io.netty.buffer.ByteBuf buffer = super.buffer(); + trackBuffer(buffer); + return buffer; + } + + @Override + public io.netty.buffer.ByteBuf buffer(final int initialCapacity) { + io.netty.buffer.ByteBuf buffer = super.buffer(initialCapacity); + trackBuffer(buffer); + return buffer; + } + + @Override + public io.netty.buffer.ByteBuf buffer(final int initialCapacity, final int maxCapacity) { + io.netty.buffer.ByteBuf buffer = super.buffer(initialCapacity, maxCapacity); + trackBuffer(buffer); + return buffer; + } + + @Override + public io.netty.buffer.ByteBuf directBuffer(final int initialCapacity, final int maxCapacity) { + io.netty.buffer.ByteBuf buffer = super.directBuffer(initialCapacity, maxCapacity); + trackBuffer(buffer); + return buffer; + } + + private void trackBuffer(final io.netty.buffer.ByteBuf buffer) { + synchronized (allocatedBuffers) { + allocatedBuffers.add(buffer); + allocationCount.incrementAndGet(); + } + } + + /** + * Asserts that all allocated buffers have been released (refCnt == 0). + * This provides explicit proof that no leaks occurred. + */ + void assertAllBuffersReleased() { + synchronized (allocatedBuffers) { + List leakedBuffers = new ArrayList<>(); + for (io.netty.buffer.ByteBuf buffer : allocatedBuffers) { + if (buffer.refCnt() > 0) { + leakedBuffers.add(buffer); + } + } + + if (!leakedBuffers.isEmpty()) { + StringBuilder message = new StringBuilder(); + message.append("BUFFER LEAK DETECTED: ") + .append(leakedBuffers.size()) + .append(" of ") + .append(allocatedBuffers.size()) + .append(" buffers were not released\n"); + + message.append(getStats()) + .append("\n"); + + for (int i = 0; i < leakedBuffers.size(); i++) { + io.netty.buffer.ByteBuf leaked = leakedBuffers.get(i); + message.append(" [").append(i).append("] ") + .append(leaked.getClass().getSimpleName()) + .append(" refCnt=").append(leaked.refCnt()) + .append(" capacity=").append(leaked.capacity()) + .append("\n"); + } + + leakedBuffers.forEach(ReferenceCounted::release); + fail(message.toString()); + } + } + } + + /** + * Returns allocation statistics for debugging. 
+ */ + String getStats() { + synchronized (allocatedBuffers) { + int leaked = 0; + int released = 0; + for (io.netty.buffer.ByteBuf buffer : allocatedBuffers) { + if (buffer.refCnt() > 0) { + leaked++; + } else { + released++; + } + } + return String.format("Allocations: %d, Released: %d, Leaked: %d", + allocationCount.get(), released, leaked); + } + } + } +} diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy index 0407baeca8a..5bab4bad335 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy @@ -47,6 +47,9 @@ class ReplyHeaderSpecification extends Specification { replyHeader.requestId == 45 replyHeader.responseTo == 23 + cleanup: + byteBuf?.release() + where: responseFlags << [0, 1, 2, 3] cursorNotFound << [false, true, false, true] @@ -79,6 +82,9 @@ class ReplyHeaderSpecification extends Specification { replyHeader.requestId == 45 replyHeader.responseTo == 23 + cleanup: + byteBuf?.release() + where: responseFlags << [0, 1, 2, 3] cursorNotFound << [false, true, false, true] @@ -105,6 +111,9 @@ class ReplyHeaderSpecification extends Specification { then: def ex = thrown(MongoInternalException) ex.getMessage() == 'Unexpected reply message opCode 2' + + cleanup: + byteBuf?.release() } def 'should throw MongoInternalException on message size < 36'() { @@ -127,6 +136,9 @@ class ReplyHeaderSpecification extends Specification { then: def ex = thrown(MongoInternalException) ex.getMessage() == 'The reply message length 35 is less than the minimum message length 36' + + cleanup: + byteBuf?.release() } def 'should throw MongoInternalException on message size > max message size'() { @@ -149,6 +161,9 @@ class ReplyHeaderSpecification extends Specification { then: def ex = thrown(MongoInternalException) ex.getMessage() == 'The reply message length 400 is greater than the maximum message length 399' + + cleanup: + byteBuf?.release() } def 'should throw MongoInternalException on num documents < 0'() { @@ -171,6 +186,9 @@ class ReplyHeaderSpecification extends Specification { then: def ex = thrown(MongoInternalException) ex.getMessage() == 'The reply message number of returned documents, -1, is expected to be 1' + + cleanup: + byteBuf?.release() } def 'should throw MongoInternalException on num documents < 0 with compressed header'() { @@ -197,5 +215,8 @@ class ReplyHeaderSpecification extends Specification { then: def ex = thrown(MongoInternalException) ex.getMessage() == 'The reply message number of returned documents, -1, is expected to be 1' + + cleanup: + byteBuf?.release() } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentSpecification.groovy deleted file mode 100644 index 8dc599706a9..00000000000 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentSpecification.groovy +++ /dev/null @@ -1,313 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.mongodb.internal.connection - -import org.bson.BsonArray -import org.bson.BsonBinaryWriter -import org.bson.BsonBoolean -import org.bson.BsonDocument -import org.bson.BsonInt32 -import org.bson.BsonNull -import org.bson.BsonValue -import org.bson.ByteBuf -import org.bson.ByteBufNIO -import org.bson.codecs.BsonDocumentCodec -import org.bson.codecs.DecoderContext -import org.bson.codecs.EncoderContext -import org.bson.io.BasicOutputBuffer -import org.bson.json.JsonMode -import org.bson.json.JsonWriterSettings -import spock.lang.Specification - -import java.nio.ByteBuffer - -import static java.util.Arrays.asList - -class ByteBufBsonDocumentSpecification extends Specification { - def emptyDocumentByteBuf = new ByteBufNIO(ByteBuffer.wrap([5, 0, 0, 0, 0] as byte[])) - ByteBuf documentByteBuf - ByteBufBsonDocument emptyByteBufDocument = new ByteBufBsonDocument(emptyDocumentByteBuf) - def document = new BsonDocument() - .append('a', new BsonInt32(1)) - .append('b', new BsonInt32(2)) - .append('c', new BsonDocument('x', BsonBoolean.TRUE)) - .append('d', new BsonArray(asList(new BsonDocument('y', BsonBoolean.FALSE), new BsonInt32(1)))) - - ByteBufBsonDocument byteBufDocument - - def setup() { - def buffer = new BasicOutputBuffer() - new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()) - ByteArrayOutputStream baos = new ByteArrayOutputStream() - buffer.pipe(baos) - documentByteBuf = new ByteBufNIO(ByteBuffer.wrap(baos.toByteArray())) - byteBufDocument = new ByteBufBsonDocument(documentByteBuf) - } - - def 'get should get the value of the given key'() { - expect: - emptyByteBufDocument.get('a') == null - byteBufDocument.get('z') == null - byteBufDocument.get('a') == new BsonInt32(1) - byteBufDocument.get('b') == new BsonInt32(2) - } - - def 'get should throw if the key is null'() { - when: - byteBufDocument.get(null) - - then: - thrown(IllegalArgumentException) - documentByteBuf.referenceCount == 1 - } - - def 'containKey should throw if the key name is null'() { - when: - byteBufDocument.containsKey(null) - - then: - thrown(IllegalArgumentException) - documentByteBuf.referenceCount == 1 - } - - def 'containsKey should find an existing key'() { - expect: - byteBufDocument.containsKey('a') - byteBufDocument.containsKey('b') - byteBufDocument.containsKey('c') - byteBufDocument.containsKey('d') - documentByteBuf.referenceCount == 1 - } - - def 'containsKey should not find a non-existing key'() { - expect: - !byteBufDocument.containsKey('e') - !byteBufDocument.containsKey('x') - !byteBufDocument.containsKey('y') - documentByteBuf.referenceCount == 1 - } - - def 'containValue should find an existing value'() { - expect: - byteBufDocument.containsValue(document.get('a')) - byteBufDocument.containsValue(document.get('b')) - byteBufDocument.containsValue(document.get('c')) - byteBufDocument.containsValue(document.get('d')) - documentByteBuf.referenceCount == 1 - } - - def 'containValue should not find a non-existing value'() { - expect: - !byteBufDocument.containsValue(new BsonInt32(3)) - 
!byteBufDocument.containsValue(new BsonDocument('e', BsonBoolean.FALSE)) - !byteBufDocument.containsValue(new BsonArray(asList(new BsonInt32(2), new BsonInt32(4)))) - documentByteBuf.referenceCount == 1 - } - - def 'isEmpty should return false when the document is not empty'() { - expect: - !byteBufDocument.isEmpty() - documentByteBuf.referenceCount == 1 - } - - def 'isEmpty should return true when the document is empty'() { - expect: - emptyByteBufDocument.isEmpty() - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'should get correct size'() { - expect: - emptyByteBufDocument.size() == 0 - byteBufDocument.size() == 4 - documentByteBuf.referenceCount == 1 - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'should get correct key set'() { - expect: - emptyByteBufDocument.keySet().isEmpty() - byteBufDocument.keySet() == ['a', 'b', 'c', 'd'] as Set - documentByteBuf.referenceCount == 1 - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'should get correct values set'() { - expect: - emptyByteBufDocument.values().isEmpty() - byteBufDocument.values() as Set == [document.get('a'), document.get('b'), document.get('c'), document.get('d')] as Set - documentByteBuf.referenceCount == 1 - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'should get correct entry set'() { - expect: - emptyByteBufDocument.entrySet().isEmpty() - byteBufDocument.entrySet() == [new TestEntry('a', document.get('a')), - new TestEntry('b', document.get('b')), - new TestEntry('c', document.get('c')), - new TestEntry('d', document.get('d'))] as Set - documentByteBuf.referenceCount == 1 - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'all write methods should throw UnsupportedOperationException'() { - when: - byteBufDocument.clear() - - then: - thrown(UnsupportedOperationException) - - when: - byteBufDocument.put('x', BsonNull.VALUE) - - then: - thrown(UnsupportedOperationException) - - when: - byteBufDocument.append('x', BsonNull.VALUE) - - then: - thrown(UnsupportedOperationException) - - when: - byteBufDocument.putAll(new BsonDocument('x', BsonNull.VALUE)) - - then: - thrown(UnsupportedOperationException) - - when: - byteBufDocument.remove(BsonNull.VALUE) - - then: - thrown(UnsupportedOperationException) - } - - def 'should get first key'() { - expect: - byteBufDocument.getFirstKey() == document.keySet().iterator().next() - documentByteBuf.referenceCount == 1 - } - - def 'getFirstKey should throw NoSuchElementException if the document is empty'() { - when: - emptyByteBufDocument.getFirstKey() - - then: - thrown(NoSuchElementException) - emptyDocumentByteBuf.referenceCount == 1 - } - - def 'should create BsonReader'() { - when: - def reader = document.asBsonReader() - - then: - new BsonDocumentCodec().decode(reader, DecoderContext.builder().build()) == document - - cleanup: - reader.close() - } - - def 'clone should make a deep copy'() { - when: - BsonDocument cloned = byteBufDocument.clone() - - then: - cloned == byteBufDocument - documentByteBuf.referenceCount == 1 - } - - def 'should serialize and deserialize'() { - given: - def baos = new ByteArrayOutputStream() - def oos = new ObjectOutputStream(baos) - - when: - oos.writeObject(byteBufDocument) - def bais = new ByteArrayInputStream(baos.toByteArray()) - def ois = new ObjectInputStream(bais) - def deserializedDocument = ois.readObject() - - then: - byteBufDocument == deserializedDocument - documentByteBuf.referenceCount == 1 - } - - def 'toJson should return equivalent'() { - expect: - document.toJson() == byteBufDocument.toJson() - 
documentByteBuf.referenceCount == 1 - } - - def 'toJson should be callable multiple times'() { - expect: - byteBufDocument.toJson() - byteBufDocument.toJson() - documentByteBuf.referenceCount == 1 - } - - def 'size should be callable multiple times'() { - expect: - byteBufDocument.size() - byteBufDocument.size() - documentByteBuf.referenceCount == 1 - } - - def 'toJson should respect JsonWriteSettings'() { - given: - def settings = JsonWriterSettings.builder().outputMode(JsonMode.SHELL).build() - - expect: - document.toJson(settings) == byteBufDocument.toJson(settings) - } - - def 'toJson should return equivalent when a ByteBufBsonDocument is nested in a BsonDocument'() { - given: - def topLevel = new BsonDocument('nested', byteBufDocument) - - expect: - new BsonDocument('nested', document).toJson() == topLevel.toJson() - } - - class TestEntry implements Map.Entry { - - private final String key - private BsonValue value - - TestEntry(String key, BsonValue value) { - this.key = key - this.value = value - } - - @Override - String getKey() { - key - } - - @Override - BsonValue getValue() { - value - } - - @Override - BsonValue setValue(final BsonValue value) { - this.value = value - } - } - -} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java new file mode 100644 index 00000000000..e12c091b01f --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufBsonDocumentTest.java @@ -0,0 +1,706 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.mongodb.internal.connection; + +import org.bson.BsonArray; +import org.bson.BsonBinaryWriter; +import org.bson.BsonBoolean; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonReader; +import org.bson.BsonString; +import org.bson.BsonValue; +import org.bson.ByteBuf; +import org.bson.ByteBufNIO; +import org.bson.RawBsonDocument; +import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.DecoderContext; +import org.bson.codecs.EncoderContext; +import org.bson.io.BasicOutputBuffer; +import org.bson.json.JsonMode; +import org.bson.json.JsonWriterSettings; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@DisplayName("ByteBufBsonDocument") +class ByteBufBsonDocumentTest { + private ByteBuf emptyDocumentByteBuf; + private ByteBuf documentByteBuf; + private ByteBufBsonDocument emptyByteBufDocument; + private ByteBufBsonDocument byteBufDocument; + private BsonDocument document; + + @BeforeEach + void setUp() { + emptyDocumentByteBuf = new ByteBufNIO(ByteBuffer.wrap(new byte[]{5, 0, 0, 0, 0})); + emptyByteBufDocument = new ByteBufBsonDocument(emptyDocumentByteBuf); + + document = new BsonDocument() + .append("a", new BsonInt32(1)) + .append("b", new BsonInt32(2)) + .append("c", new BsonDocument("x", BsonBoolean.TRUE)) + .append("d", new BsonArray(asList( + new BsonDocument("y", BsonBoolean.FALSE), + new BsonInt32(1) + ))); + + RawBsonDocument rawBsonDocument = RawBsonDocument.parse(document.toString()); + documentByteBuf = rawBsonDocument.getByteBuffer(); + byteBufDocument = new ByteBufBsonDocument(documentByteBuf); + } + + @AfterEach + void tearDown() { + if (byteBufDocument != null) { + byteBufDocument.close(); + } + if (emptyByteBufDocument != null) { + emptyByteBufDocument.close(); + } + } + + // Basic Operations + + @Test + @DisplayName("get() returns value for existing key, null for missing key") + void getShouldReturnCorrectValue() { + assertNull(emptyByteBufDocument.get("a")); + assertNull(byteBufDocument.get("z")); + assertEquals(new BsonInt32(1), byteBufDocument.get("a")); + assertEquals(new BsonInt32(2), byteBufDocument.get("b")); + } + + @Test + @DisplayName("get() throws IllegalArgumentException for null key") + void getShouldThrowForNullKey() { + assertThrows(IllegalArgumentException.class, () -> byteBufDocument.get(null)); + assertEquals(1, documentByteBuf.getReferenceCount()); + } + + @Test + 
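// Reviewer note (assumption): read-only access parses the backing ByteBuf in place and
+    // never retains or releases it, so the reference count fixed at construction time should
+    // survive every accessor; the getReferenceCount() assertions in these tests pin that down.
+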
@DisplayName("containsKey() finds existing keys and rejects missing keys") + void containsKeyShouldWork() { + assertThrows(IllegalArgumentException.class, () -> byteBufDocument.containsKey(null)); + assertTrue(byteBufDocument.containsKey("a")); + assertTrue(byteBufDocument.containsKey("d")); + assertFalse(byteBufDocument.containsKey("z")); + assertEquals(1, documentByteBuf.getReferenceCount()); + } + + @Test + @DisplayName("containsValue() finds existing values and rejects missing values") + void containsValueShouldWork() { + assertTrue(byteBufDocument.containsValue(document.get("a"))); + assertTrue(byteBufDocument.containsValue(document.get("c"))); + assertFalse(byteBufDocument.containsValue(new BsonInt32(999))); + assertEquals(1, documentByteBuf.getReferenceCount()); + } + + @Test + @DisplayName("isEmpty() returns correct result") + void isEmptyShouldWork() { + assertTrue(emptyByteBufDocument.isEmpty()); + assertFalse(byteBufDocument.isEmpty()); + } + + @Test + @DisplayName("size() returns correct count") + void sizeShouldWork() { + assertEquals(0, emptyByteBufDocument.size()); + assertEquals(4, byteBufDocument.size()); + assertEquals(4, byteBufDocument.size()); // Verify caching works + } + + @Test + @DisplayName("getFirstKey() returns first key or throws for empty document") + void getFirstKeyShouldWork() { + assertEquals("a", byteBufDocument.getFirstKey()); + assertThrows(NoSuchElementException.class, () -> emptyByteBufDocument.getFirstKey()); + } + + // Collection Views + + @Test + @DisplayName("keySet() returns all keys") + void keySetShouldWork() { + assertTrue(emptyByteBufDocument.keySet().isEmpty()); + assertEquals(new HashSet<>(asList("a", "b", "c", "d")), byteBufDocument.keySet()); + } + + @Test + @DisplayName("values() returns all values") + void valuesShouldWork() { + assertTrue(emptyByteBufDocument.values().isEmpty()); + Set expected = new HashSet<>(asList( + document.get("a"), document.get("b"), document.get("c"), document.get("d") + )); + assertEquals(expected, new HashSet<>(byteBufDocument.values())); + } + + @Test + @DisplayName("entrySet() returns all entries") + void entrySetShouldWork() { + assertTrue(emptyByteBufDocument.entrySet().isEmpty()); + Set> expected = new HashSet<>(asList( + new AbstractMap.SimpleImmutableEntry<>("a", document.get("a")), + new AbstractMap.SimpleImmutableEntry<>("b", document.get("b")), + new AbstractMap.SimpleImmutableEntry<>("c", document.get("c")), + new AbstractMap.SimpleImmutableEntry<>("d", document.get("d")) + )); + assertEquals(expected, byteBufDocument.entrySet()); + } + + // Type-Specific Accessors + + @Test + @DisplayName("getDocument() returns nested document") + void getDocumentShouldWork() { + BsonDocument nested = byteBufDocument.getDocument("c"); + assertNotNull(nested); + assertEquals(BsonBoolean.TRUE, nested.get("x")); + } + + @Test + @DisplayName("getArray() returns array") + void getArrayShouldWork() { + BsonArray array = byteBufDocument.getArray("d"); + assertNotNull(array); + assertEquals(2, array.size()); + } + + @Test + @DisplayName("get() with default value works correctly") + void getWithDefaultShouldWork() { + assertEquals(new BsonInt32(1), byteBufDocument.get("a", new BsonInt32(999))); + assertEquals(new BsonInt32(999), byteBufDocument.get("missing", new BsonInt32(999))); + } + + @Test + @DisplayName("Type check methods return correct results") + void typeChecksShouldWork() { + assertTrue(byteBufDocument.isNumber("a")); + assertTrue(byteBufDocument.isInt32("a")); + assertTrue(byteBufDocument.isDocument("c")); + 
assertTrue(byteBufDocument.isArray("d")); + assertFalse(byteBufDocument.isDocument("a")); + } + + // Immutability + + @Test + @DisplayName("All write methods throw UnsupportedOperationException") + void writeMethodsShouldThrow() { + assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.clear()); + assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.put("x", new BsonInt32(1))); + assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.append("x", new BsonInt32(1))); + assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.putAll(new BsonDocument())); + assertThrows(UnsupportedOperationException.class, () -> byteBufDocument.remove("a")); + } + + // Conversion and Serialization + + @Test + @DisplayName("toBsonDocument() returns equivalent document and caches result") + void toBsonDocumentShouldWork() { + assertEquals(document, byteBufDocument.toBsonDocument()); + BsonDocument first = byteBufDocument.toBsonDocument(); + BsonDocument second = byteBufDocument.toBsonDocument(); + assertEquals(first, second); + } + + @Test + @DisplayName("asBsonReader() creates valid reader") + void asBsonReaderShouldWork() { + try (BsonReader reader = byteBufDocument.asBsonReader()) { + BsonDocument decoded = new BsonDocumentCodec().decode(reader, DecoderContext.builder().build()); + assertEquals(document, decoded); + } + } + + @Test + @DisplayName("toJson() returns correct JSON with different settings") + void toJsonShouldWork() { + assertEquals(document.toJson(), byteBufDocument.toJson()); + assertNotNull(byteBufDocument.toJson()); // Verify caching + + JsonWriterSettings shellSettings = JsonWriterSettings.builder().outputMode(JsonMode.SHELL).build(); + assertEquals(document.toJson(shellSettings), byteBufDocument.toJson(shellSettings)); + } + + @Test + @DisplayName("toString() returns equivalent string") + void toStringShouldWork() { + assertEquals(document.toString(), byteBufDocument.toString()); + } + + @Test + @DisplayName("clone() creates deep copy") + void cloneShouldWork() { + BsonDocument cloned = byteBufDocument.clone(); + assertEquals(byteBufDocument, cloned); + } + + @Test + @DisplayName("Java serialization works correctly") + void serializationShouldWork() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + new ObjectOutputStream(baos).writeObject(byteBufDocument); + Object deserialized = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray())).readObject(); + assertEquals(byteBufDocument, deserialized); + } + + // Equality and HashCode + + @Test + @DisplayName("equals() and hashCode() work correctly") + void equalsAndHashCodeShouldWork() { + assertEquals(document, byteBufDocument); + assertEquals(byteBufDocument, document); + assertEquals(document.hashCode(), byteBufDocument.hashCode()); + assertNotEquals(byteBufDocument, new BsonDocument("x", new BsonInt32(99))); + } + + // Resource Management + + @Test + @DisplayName("Closed document throws IllegalStateException on all operations") + void closedDocumentShouldThrow() { + byteBufDocument.close(); + assertThrows(IllegalStateException.class, () -> byteBufDocument.size()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.isEmpty()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.containsKey("a")); + assertThrows(IllegalStateException.class, () -> byteBufDocument.get("a")); + assertThrows(IllegalStateException.class, () -> byteBufDocument.keySet()); + assertThrows(IllegalStateException.class, () -> 
byteBufDocument.values()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.entrySet()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.getFirstKey()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.toBsonDocument()); + assertThrows(IllegalStateException.class, () -> byteBufDocument.toJson()); + } + + @Test + @DisplayName("close() can be called multiple times safely") + void closeIsIdempotent() { + byteBufDocument.close(); + byteBufDocument.close(); // Should not throw + } + + @Test + @DisplayName("Nested documents are closed when parent is closed") + void nestedDocumentsClosedWithParent() { + BsonDocument doc = new BsonDocument("outer", new BsonDocument("inner", new BsonInt32(42))); + ByteBuf buf = createByteBufFromDocument(doc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + + BsonDocument retrieved = byteBufDoc.getDocument("outer"); + byteBufDoc.close(); + + assertThrows(IllegalStateException.class, byteBufDoc::size); + if (retrieved instanceof ByteBufBsonDocument) { + assertThrows(IllegalStateException.class, retrieved::size); + } + } + + @Test + @DisplayName("Nested arrays are closed when parent is closed") + void nestedArraysClosedWithParent() { + BsonDocument doc = new BsonDocument("arr", new BsonArray(asList( + new BsonInt32(1), new BsonDocument("x", new BsonInt32(2)) + ))); + ByteBuf buf = createByteBufFromDocument(doc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + + BsonArray retrieved = byteBufDoc.getArray("arr"); + byteBufDoc.close(); + + assertThrows(IllegalStateException.class, byteBufDoc::size); + if (retrieved instanceof ByteBufBsonArray) { + assertThrows(IllegalStateException.class, retrieved::size); + } + } + + @Test + @DisplayName("Deeply nested structures are closed recursively") + void deeplyNestedClosedRecursively() { + BsonDocument doc = new BsonDocument() + .append("level1", new BsonArray(asList( + new BsonDocument("level2", new BsonDocument("level3", new BsonInt32(999))), + new BsonInt32(1) + ))) + .append("sibling", new BsonDocument("key", new BsonString("value"))); + + ByteBuf buf = createByteBufFromDocument(doc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + + BsonArray level1 = byteBufDoc.getArray("level1"); + byteBufDoc.getDocument("sibling"); + + if (level1.get(0).isDocument()) { + BsonDocument level2Doc = level1.get(0).asDocument(); + if (level2Doc.containsKey("level2")) { + assertEquals(new BsonInt32(999), level2Doc.getDocument("level2").get("level3")); + } + } + + byteBufDoc.close(); + assertThrows(IllegalStateException.class, byteBufDoc::size); + } + + @Test + @DisplayName("Iteration tracks resources correctly") + void iterationTracksResources() { + BsonDocument doc = new BsonDocument() + .append("doc1", new BsonDocument("a", new BsonInt32(1))) + .append("arr1", new BsonArray(asList(new BsonInt32(2), new BsonInt32(3)))) + .append("primitive", new BsonString("test")); + + ByteBuf buf = createByteBufFromDocument(doc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + + int count = 0; + for (Map.Entry entry : byteBufDoc.entrySet()) { + assertNotNull(entry.getKey()); + assertNotNull(entry.getValue()); + count++; + } + assertEquals(3, count); + + byteBufDoc.close(); + assertThrows(IllegalStateException.class, byteBufDoc::size); + } + + @Test + @DisplayName("toBsonDocument() handles nested structures and allows close") + void toBsonDocumentHandlesNestedStructures() { + BsonDocument complexDoc = new BsonDocument() + 
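// Fixture note: one nested document plus an array that itself holds a document, the
+                // shapes that may surface as lazy ByteBufBson* views elsewhere in this class.
+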
.append("doc", new BsonDocument("x", new BsonInt32(1))) + .append("arr", new BsonArray(asList(new BsonDocument("y", new BsonInt32(2)), new BsonInt32(3)))); + + ByteBuf buf = createByteBufFromDocument(complexDoc); + ByteBufBsonDocument byteBufDoc = new ByteBufBsonDocument(buf); + + BsonDocument hydrated = byteBufDoc.toBsonDocument(); + assertEquals(complexDoc, hydrated); + + byteBufDoc.close(); + assertThrows(IllegalStateException.class, byteBufDoc::size); + } + + // Sequence Fields (OP_MSG) + + @Test + @DisplayName("Sequence field is accessible as array of ByteBufBsonDocuments") + void sequenceFieldAccessibleAsArray() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 3)) { + + BsonValue documentsValue = commandDoc.get("documents"); + assertNotNull(documentsValue); + assertTrue(documentsValue.isArray()); + + BsonArray documents = documentsValue.asArray(); + assertEquals(3, documents.size()); + + for (int i = 0; i < 3; i++) { + BsonValue doc = documents.get(i); + assertInstanceOf(ByteBufBsonDocument.class, doc); + assertEquals(new BsonInt32(i), doc.asDocument().get("_id")); + assertEquals(new BsonString("doc" + i), doc.asDocument().get("name")); + } + } + } + + @Test + @DisplayName("Sequence field is included in size, keySet, values, and entrySet") + void sequenceFieldIncludedInCollectionViews() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) { + + assertTrue(commandDoc.size() >= 3); + assertTrue(commandDoc.keySet().contains("documents")); + assertTrue(commandDoc.keySet().contains("insert")); + + boolean foundDocumentsArray = false; + for (BsonValue value : commandDoc.values()) { + if (value.isArray() && value.asArray().size() == 2) { + foundDocumentsArray = true; + break; + } + } + assertTrue(foundDocumentsArray); + + boolean foundDocumentsEntry = false; + for (Map.Entry entry : commandDoc.entrySet()) { + if ("documents".equals(entry.getKey())) { + foundDocumentsEntry = true; + assertEquals(2, entry.getValue().asArray().size()); + break; + } + } + assertTrue(foundDocumentsEntry); + } + } + + @Test + @DisplayName("containsKey and containsValue work with sequence fields") + void containsMethodsWorkWithSequenceFields() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 3)) { + + assertTrue(commandDoc.containsKey("documents")); + assertTrue(commandDoc.containsKey("insert")); + assertFalse(commandDoc.containsKey("nonexistent")); + + BsonDocument expectedDoc = new BsonDocument() + .append("_id", new BsonInt32(1)) + .append("name", new BsonString("doc1")); + assertTrue(commandDoc.containsValue(expectedDoc)); + } + } + + @Test + @DisplayName("Sequence field documents are closed when parent is closed") + void sequenceFieldDocumentsClosedWithParent() { + ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2); + + BsonArray documents = commandDoc.getArray("documents"); + List docRefs = new ArrayList<>(); + for (BsonValue doc : documents) { + docRefs.add(doc.asDocument()); + } + + commandDoc.close(); + output.close(); + + assertThrows(IllegalStateException.class, commandDoc::size); + for (BsonDocument doc : docRefs) { + if (doc instanceof 
ByteBufBsonDocument) { + assertThrows(IllegalStateException.class, doc::size); + } + } + } + + @Test + @DisplayName("Sequence field is cached on multiple access") + void sequenceFieldCached() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) { + + BsonArray first = commandDoc.getArray("documents"); + BsonArray second = commandDoc.getArray("documents"); + assertNotNull(first); + assertEquals(first.size(), second.size()); + } + } + + @Test + @DisplayName("toBsonDocument() hydrates sequence fields to regular BsonDocuments") + void toBsonDocumentHydratesSequenceFields() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) { + + BsonDocument hydrated = commandDoc.toBsonDocument(); + assertTrue(hydrated.containsKey("documents")); + + BsonArray documents = hydrated.getArray("documents"); + assertEquals(2, documents.size()); + for (BsonValue doc : documents) { + assertFalse(doc instanceof ByteBufBsonDocument); + } + } + } + + @Test + @DisplayName("Sequence field with nested documents works correctly") + void sequenceFieldWithNestedDocuments() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider())) { + ByteBufBsonDocument commandDoc = createNestedCommandMessageDocument(output); + + BsonArray documents = commandDoc.getArray("documents"); + assertEquals(2, documents.size()); + + BsonDocument firstDoc = documents.get(0).asDocument(); + BsonDocument nested = firstDoc.getDocument("nested"); + assertEquals(new BsonInt32(0), nested.get("inner")); + + BsonArray array = firstDoc.getArray("array"); + assertEquals(2, array.size()); + + commandDoc.close(); + } + } + + @Test + @DisplayName("Empty sequence field returns empty array") + void emptySequenceField() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 0)) { + + assertTrue(commandDoc.containsKey("insert")); + assertTrue(commandDoc.containsKey("documents")); + assertTrue(commandDoc.getArray("documents").isEmpty()); + } + } + + @Test + @DisplayName("getFirstKey() returns body field, not sequence field") + void getFirstKeyReturnsBodyField() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) { + + assertEquals("insert", commandDoc.getFirstKey()); + } + } + + @Test + @DisplayName("toJson() includes sequence fields") + void toJsonIncludesSequenceFields() { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc = createCommandMessageDocument(output, 2)) { + + String json = commandDoc.toJson(); + assertTrue(json.contains("documents")); + assertTrue(json.contains("_id")); + } + } + + @Test + @DisplayName("equals() and hashCode() include sequence fields") + void equalsAndHashCodeIncludeSequenceFields() { + try (ByteBufferBsonOutput output1 = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc1 = createCommandMessageDocument(output1, 2); + ByteBufferBsonOutput output2 = new ByteBufferBsonOutput(new SimpleBufferProvider()); + ByteBufBsonDocument commandDoc2 = createCommandMessageDocument(output2, 2)) { + + assertEquals(commandDoc1.toBsonDocument(), 
commandDoc2.toBsonDocument());
+            assertEquals(commandDoc1.hashCode(), commandDoc2.hashCode());
+        }
+    }
+
+    // --- Helper Methods ---
+
+    private ByteBufBsonDocument createCommandMessageDocument(final ByteBufferBsonOutput output, final int numDocuments) {
+        BsonDocument bodyDoc = new BsonDocument()
+                .append("insert", new BsonString("test"))
+                .append("$db", new BsonString("db"));
+
+        byte[] bodyBytes = encodeBsonDocument(bodyDoc);
+        List<byte[]> sequenceDocBytes = new ArrayList<>();
+        for (int i = 0; i < numDocuments; i++) {
+            BsonDocument seqDoc = new BsonDocument()
+                    .append("_id", new BsonInt32(i))
+                    .append("name", new BsonString("doc" + i));
+            sequenceDocBytes.add(encodeBsonDocument(seqDoc));
+        }
+
+        writeOpMsgFormat(output, bodyBytes, "documents", sequenceDocBytes);
+
+        List<ByteBuf> buffers = output.getByteBuffers();
+        return ByteBufBsonDocument.createCommandMessage(new CompositeByteBuf(buffers));
+    }
+
+    private ByteBufBsonDocument createNestedCommandMessageDocument(final ByteBufferBsonOutput output) {
+        BsonDocument bodyDoc = new BsonDocument()
+                .append("insert", new BsonString("test"))
+                .append("$db", new BsonString("db"));
+
+        byte[] bodyBytes = encodeBsonDocument(bodyDoc);
+        List<byte[]> sequenceDocBytes = new ArrayList<>();
+        for (int i = 0; i < 2; i++) {
+            BsonDocument seqDoc = new BsonDocument()
+                    .append("_id", new BsonInt32(i))
+                    .append("nested", new BsonDocument("inner", new BsonInt32(i * 10)))
+                    .append("array", new BsonArray(asList(
+                            new BsonInt32(i),
+                            new BsonDocument("arrayNested", new BsonString("value" + i))
+                    )));
+            sequenceDocBytes.add(encodeBsonDocument(seqDoc));
+        }
+
+        writeOpMsgFormat(output, bodyBytes, "documents", sequenceDocBytes);
+        return ByteBufBsonDocument.createCommandMessage(new CompositeByteBuf(output.getByteBuffers()));
+    }
+
+    // Lays the command out in OP_MSG form: the type 0 body document first, then one
+    // type 1 "document sequence" section, i.e. a payload-type byte of 1, an int32 section
+    // size (counting itself, the identifier cstring, and the documents), the identifier
+    // cstring, and the BSON documents back to back.
+    private void writeOpMsgFormat(final ByteBufferBsonOutput output, final byte[] bodyBytes,
+                                  final String sequenceIdentifier, final List<byte[]> sequenceDocBytes) {
+        output.writeBytes(bodyBytes, 0, bodyBytes.length);
+
+        int sequencePayloadSize = sequenceDocBytes.stream().mapToInt(b -> b.length).sum();
+        int sequenceSectionSize = 4 + sequenceIdentifier.length() + 1 + sequencePayloadSize;
+
+        output.writeByte(1);
+        output.writeInt32(sequenceSectionSize);
+        output.writeCString(sequenceIdentifier);
+        for (byte[] docBytes : sequenceDocBytes) {
+            output.writeBytes(docBytes, 0, docBytes.length);
+        }
+    }
+
+    private static byte[] encodeBsonDocument(final BsonDocument doc) {
+        try {
+            BasicOutputBuffer buffer = new BasicOutputBuffer();
+            new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), doc, EncoderContext.builder().build());
+            ByteArrayOutputStream baos = new ByteArrayOutputStream();
+            buffer.pipe(baos);
+            return baos.toByteArray();
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private static ByteBuf createByteBufFromDocument(final BsonDocument doc) {
+        return new ByteBufNIO(ByteBuffer.wrap(encodeBsonDocument(doc)));
+    }
+
+    private static class SimpleBufferProvider implements BufferProvider {
+        @NotNull
+        @Override
+        public ByteBuf getBuffer(final int size) {
+            return new ByteBufNIO(ByteBuffer.allocate(size).order(ByteOrder.LITTLE_ENDIAN));
+        }
+    }
+}
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java
index 8988ea3d6d9..40464d81208 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java
+++ 
b/driver-core/src/test/unit/com/mongodb/internal/connection/ByteBufferBsonOutputTest.java @@ -35,6 +35,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.OutputStream; import java.nio.ByteBuffer; import java.nio.charset.CharacterCodingException; import java.nio.charset.StandardCharsets; @@ -42,6 +43,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.stream.Stream; @@ -63,7 +65,9 @@ import static java.util.stream.IntStream.rangeClosed; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; final class ByteBufferBsonOutputTest { @@ -495,7 +499,6 @@ void shouldWriteUtf8CString(final boolean useBranch, final BufferProvider buffer @ParameterizedTest(name = "should get byte buffers as little endian. Parameters: useBranch={0}, bufferProvider={1}") @MethodSource("bufferProvidersWithBranches") void shouldGetByteBuffersAsLittleEndian(final boolean useBranch, final BufferProvider bufferProvider) { - List byteBuffers = new ArrayList<>(); try (ByteBufferBsonOutput out = new ByteBufferBsonOutput(bufferProvider)) { byte[] v = {1, 0, 0, 0}; if (useBranch) { @@ -506,9 +509,8 @@ void shouldGetByteBuffersAsLittleEndian(final boolean useBranch, final BufferPro out.writeBytes(v); } - byteBuffers = out.getByteBuffers(); + List byteBuffers = out.getByteBuffers(); assertEquals(1, byteBuffers.get(0).getInt()); - } finally { byteBuffers.forEach(ByteBuf::release); } } @@ -775,7 +777,7 @@ void shouldClose(final boolean useBranch, final BufferProvider bufferProvider) { @ValueSource(ints = {1, INITIAL_BUFFER_SIZE, INITIAL_BUFFER_SIZE * 3}) void shouldHandleMixedBranchingAndTruncating(final int reps) throws CharacterCodingException { BiConsumer write = (out, c) -> { - Assertions.assertTrue((byte) c.charValue() == c); + assertTrue((byte) c.charValue() == c); for (int i = 0; i < reps; i++) { out.writeByte(c); } @@ -840,7 +842,7 @@ void shouldThrowExceptionWhenIntegerDoesNotFitWriteInt32(final BufferProvider bu @MethodSource("bufferProviders") void shouldThrowExceptionWhenAbsolutePositionIsNegative(final BufferProvider bufferProvider) { try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(bufferProvider)) { - Assertions.assertThrows(IllegalArgumentException.class, () -> + assertThrows(IllegalArgumentException.class, () -> output.writeInt32(-1, 5678) ); } @@ -892,7 +894,7 @@ void shouldWriteInt32AbsoluteValueWithinSpanningBuffers( final List expectedBuffers, final BufferProvider bufferProvider) { - List buffers = new ArrayList<>(); + List buffers; try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(size -> bufferProvider.getBuffer(Integer.BYTES))) { //given @@ -905,7 +907,6 @@ void shouldWriteInt32AbsoluteValueWithinSpanningBuffers( buffers = output.getByteBuffers(); assertEquals(expectedBuffers.size(), buffers.size(), "Number of buffers mismatch"); assertBufferContents(expectedBuffers, buffers); - } finally { buffers.forEach(ByteBuf::release); } } @@ -1022,9 +1023,8 @@ void shouldWriteInt32WithinSpanningBuffers( final int expectedLastBufferPosition, final BufferProvider bufferProvider) { - List buffers = new ArrayList<>(); - try 
(ByteBufferBsonOutput output = - new ByteBufferBsonOutput(size -> bufferProvider.getBuffer(Integer.BYTES))) { + List buffers; + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(size -> bufferProvider.getBuffer(Integer.BYTES))) { //given initialData.forEach(output::writeBytes); @@ -1033,14 +1033,9 @@ void shouldWriteInt32WithinSpanningBuffers( output.writeInt32(intValue); //then - //getByteBuffers returns ByteBuffers with limit() set to position, position set to 0. buffers = output.getByteBuffers(); assertEquals(expectedBuffers.size(), buffers.size(), "Number of buffers mismatch"); assertBufferContents(expectedBuffers, buffers); - - assertEquals(expectedLastBufferPosition, buffers.get(buffers.size() - 1).limit()); - assertEquals(expectedOutputPosition, output.getPosition()); - } finally { buffers.forEach(ByteBuf::release); } } @@ -1057,9 +1052,8 @@ void shouldWriteInt64WithinSpanningBuffers( final int expectedLastBufferPosition, final BufferProvider bufferProvider) { - List buffers = new ArrayList<>(); - try (ByteBufferBsonOutput output = - new ByteBufferBsonOutput(size -> bufferProvider.getBuffer(Long.BYTES))) { + List buffers; + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(size -> bufferProvider.getBuffer(Long.BYTES))) { //given initialData.forEach(output::writeBytes); @@ -1075,7 +1069,6 @@ void shouldWriteInt64WithinSpanningBuffers( assertEquals(expectedLastBufferPosition, buffers.get(buffers.size() - 1).limit()); assertEquals(expectedOutputPosition, output.getPosition()); - } finally { buffers.forEach(ByteBuf::release); } } @@ -1092,7 +1085,6 @@ void shouldWriteDoubleWithinSpanningBuffers( final int expectedLastBufferPosition, final BufferProvider bufferProvider) { - List buffers = new ArrayList<>(); try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(size -> bufferProvider.getBuffer(Long.BYTES))) { @@ -1104,13 +1096,12 @@ void shouldWriteDoubleWithinSpanningBuffers( //then //getByteBuffers returns ByteBuffers with limit() set to position, position set to 0. 
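+            // (Reviewer note, assumption) the returned buffers are duplicates owned by the
+            // caller, while the output keeps its own references; hence the explicit release
+            // on the success path here instead of the old finally block.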
- buffers = output.getByteBuffers(); + List buffers = output.getByteBuffers(); assertEquals(expectedBuffers.size(), buffers.size(), "Number of buffers mismatch"); assertBufferContents(expectedBuffers, buffers); assertEquals(expectedLastBufferPosition, buffers.get(buffers.size() - 1).limit()); assertEquals(expectedOutputPosition, output.getPosition()); - } finally { buffers.forEach(ByteBuf::release); } } @@ -1422,7 +1413,6 @@ private void testWriteCStringAcrossBuffers(final BufferProvider bufferProvider, bufferAllocationSize, actualByteBuffers, actualFlattenedByteBuffersBytes); - } finally { actualByteBuffers.forEach(ByteBuf::release); } } @@ -1435,8 +1425,6 @@ private void testWriteStringAcrossBuffers(final BufferProvider bufferProvider, final byte[] expectedEncoding) throws IOException { for (int startingOffset = 0; startingOffset <= bufferAllocationSize; startingOffset++) { //given - List actualByteBuffers = emptyList(); - try (ByteBufferBsonOutput actualBsonOutput = new ByteBufferBsonOutput( size -> bufferProvider.getBuffer(bufferAllocationSize))) { // Write an initial startingOffset of empty bytes to shift the start position @@ -1446,7 +1434,7 @@ private void testWriteStringAcrossBuffers(final BufferProvider bufferProvider, actualBsonOutput.writeString(stringToEncode); // then - actualByteBuffers = actualBsonOutput.getDuplicateByteBuffers(); + List actualByteBuffers = actualBsonOutput.getDuplicateByteBuffers(); byte[] actualFlattenedByteBuffersBytes = getBytes(actualBsonOutput); assertEncodedStringSize(codePoint, @@ -1459,7 +1447,7 @@ private void testWriteStringAcrossBuffers(final BufferProvider bufferProvider, bufferAllocationSize, actualByteBuffers, actualFlattenedByteBuffersBytes); - } finally { + actualByteBuffers.forEach(ByteBuf::release); } } @@ -1499,6 +1487,7 @@ private void testWriteStringAcrossBuffersWithBranch(final BufferProvider bufferP bufferAllocationSize, actualBranchByteBuffers, actualFlattenedByteBuffersBytes); + actualBranchByteBuffers.forEach(ByteBuf::release); } // then @@ -1515,10 +1504,7 @@ private void testWriteStringAcrossBuffersWithBranch(final BufferProvider bufferP bufferAllocationSize, actualByteBuffers, actualFlattenedByteBuffersBytes); - - } finally { actualByteBuffers.forEach(ByteBuf::release); - actualBranchByteBuffers.forEach(ByteBuf::release); } } } @@ -1583,7 +1569,6 @@ private void testWriteCStringAcrossBufferWithBranch(final BufferProvider bufferP bufferAllocationSize, actualByteBuffers, actualFlattenedByteBuffersBytes); - } finally { actualByteBuffers.forEach(ByteBuf::release); actualBranchByteBuffers.forEach(ByteBuf::release); } @@ -1652,6 +1637,330 @@ public char[] toSurrogatePair(final int codePoint) { } + @Nested + @DisplayName("ByteBuf Leak Detection and Release Tests") + class LeakDetectionTests { + + @DisplayName("should release all buffers when close is called") + @ParameterizedTest(name = "should release all buffers when close is called. 
BufferProvider={0}") + @MethodSource("com.mongodb.internal.connection.ByteBufferBsonOutputTest#bufferProviders") + void shouldReleaseAllBuffersOnClose(final BufferProvider bufferProvider) { + List allocatedBuffers = new ArrayList<>(); + BufferProvider trackingProvider = size -> { + ByteBuf buf = bufferProvider.getBuffer(size); + allocatedBuffers.add(buf); + return buf; + }; + + ByteBufferBsonOutput output = new ByteBufferBsonOutput(trackingProvider); + // Write enough to allocate multiple buffers + output.writeBytes(new byte[INITIAL_BUFFER_SIZE * 3]); + + // Verify buffers are allocated and have positive reference count + assertFalse(allocatedBuffers.isEmpty(), "Should have allocated buffers"); + for (ByteBuf buf : allocatedBuffers) { + assertTrue(buf.getReferenceCount() > 0, "Buffer should have positive ref count before close"); + } + + output.close(); + + // Verify all buffers are released + for (ByteBuf buf : allocatedBuffers) { + assertEquals(0, buf.getReferenceCount(), "Buffer should be released after close"); + } + } + + @DisplayName("should release all buffers when truncateToPosition removes buffers") + @ParameterizedTest(name = "should release all buffers when truncateToPosition removes buffers. BufferProvider={0}") + @MethodSource("com.mongodb.internal.connection.ByteBufferBsonOutputTest#bufferProviders") + void shouldReleaseBuffersOnTruncate(final BufferProvider bufferProvider) { + List allocatedBuffers = new ArrayList<>(); + BufferProvider trackingProvider = size -> { + ByteBuf buf = bufferProvider.getBuffer(size); + allocatedBuffers.add(buf); + return buf; + }; + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(trackingProvider)) { + // Write enough to allocate multiple buffers + output.writeBytes(new byte[INITIAL_BUFFER_SIZE * 4]); + + int buffersBeforeTruncate = allocatedBuffers.size(); + assertTrue(buffersBeforeTruncate > 1, "Should have multiple buffers"); + + // Truncate to first buffer only + output.truncateToPosition(INITIAL_BUFFER_SIZE / 2); + + // Verify truncated buffers are released (all but first should be released) + for (int i = 1; i < allocatedBuffers.size(); i++) { + assertEquals(0, allocatedBuffers.get(i).getReferenceCount(), + "Truncated buffer " + i + " should be released"); + } + + // First buffer should still be retained + assertTrue(allocatedBuffers.get(0).getReferenceCount() > 0, + "First buffer should still be retained"); + } + } + + @DisplayName("caller must release buffers returned by getByteBuffers") + @ParameterizedTest(name = "caller must release buffers returned by getByteBuffers. 
BufferProvider={0}") + @MethodSource("com.mongodb.internal.connection.ByteBufferBsonOutputTest#bufferProviders") + void callerMustReleaseGetByteBuffersResult(final BufferProvider bufferProvider) { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(bufferProvider)) { + output.writeBytes(new byte[INITIAL_BUFFER_SIZE * 2]); + + List returnedBuffers = output.getByteBuffers(); + assertFalse(returnedBuffers.isEmpty()); + + // Verify returned buffers have reference count of 1 (from duplicate) + for (ByteBuf buf : returnedBuffers) { + assertTrue(buf.getReferenceCount() > 0, + "Returned buffer should have positive ref count"); + } + + // Caller releases the buffers + returnedBuffers.forEach(ByteBuf::release); + + // Verify released + for (ByteBuf buf : returnedBuffers) { + assertEquals(0, buf.getReferenceCount(), + "Returned buffer should be released after caller releases it"); + } + + // Output should still be usable and internal buffers should still be valid + output.writeByte(1); + assertEquals(INITIAL_BUFFER_SIZE * 2 + 1, output.getPosition()); + } + } + + @DisplayName("caller must release buffers returned by getDuplicateByteBuffers") + @ParameterizedTest(name = "caller must release buffers returned by getDuplicateByteBuffers. BufferProvider={0}") + @MethodSource("com.mongodb.internal.connection.ByteBufferBsonOutputTest#bufferProviders") + void callerMustReleaseDuplicateByteBuffersResult(final BufferProvider bufferProvider) { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(bufferProvider)) { + output.writeBytes(new byte[INITIAL_BUFFER_SIZE * 2]); + + List returnedBuffers = output.getDuplicateByteBuffers(); + assertFalse(returnedBuffers.isEmpty()); + + // Caller releases the buffers + returnedBuffers.forEach(ByteBuf::release); + + // Verify released + for (ByteBuf buf : returnedBuffers) { + assertEquals(0, buf.getReferenceCount(), + "Returned buffer should be released after caller releases it"); + } + } + } + + @DisplayName("pipe should release temporary buffers even on exception") + @Test + void pipeShouldReleaseBuffersOnException() { + List allocatedBuffers = new ArrayList<>(); + BufferProvider trackingProvider = size -> { + ByteBuf buf = new SimpleBufferProvider().getBuffer(size); + allocatedBuffers.add(buf); + return buf; + }; + + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(trackingProvider)) { + output.writeBytes(new byte[INITIAL_BUFFER_SIZE]); + + // Create an OutputStream that throws + OutputStream failingStream = new OutputStream() { + @Override + public void write(final int b) throws IOException { + throw new IOException("Simulated failure"); + } + + @Override + public void write(final byte[] b, final int off, final int len) throws IOException { + throw new IOException("Simulated failure"); + } + }; + + assertThrows(IOException.class, () -> output.pipe(failingStream)); + + // Output should still be valid and usable + assertTrue(output.isOpen()); + } + + // Verify all buffers are released after close + for (ByteBuf buf : allocatedBuffers) { + assertEquals(0, buf.getReferenceCount(), "Buffer should be released after close"); + } + } + + @DisplayName("branch should transfer buffer ownership to parent on close") + @ParameterizedTest(name = "branch should transfer buffer ownership to parent on close. 
BufferProvider={0}") + @MethodSource("com.mongodb.internal.connection.ByteBufferBsonOutputTest#bufferProviders") + void branchShouldTransferOwnershipToParent(final BufferProvider bufferProvider) { + AtomicInteger parentBufferCount = new AtomicInteger(); + AtomicInteger branchBufferCount = new AtomicInteger(); + List allBuffers = new ArrayList<>(); + + try (ByteBufferBsonOutput parent = new ByteBufferBsonOutput(size -> { + ByteBuf buf = bufferProvider.getBuffer(size); + parentBufferCount.incrementAndGet(); + allBuffers.add(buf); + return buf; + })) { + parent.writeBytes(new byte[100]); + int parentBuffersAllocated = parentBufferCount.get(); + + try (ByteBufferBsonOutput.Branch branch = parent.branch()) { + // Use a separate tracking for branch allocations + branch.writeBytes(new byte[100]); + branchBufferCount.set(allBuffers.size() - parentBuffersAllocated); + } + + // After branch close, parent should have merged the branch buffers + // All buffers should still be retained by parent + for (ByteBuf buf : allBuffers) { + assertTrue(buf.getReferenceCount() > 0, + "Buffer should be retained by parent after branch merge"); + } + } + + // After parent close, all buffers should be released + for (ByteBuf buf : allBuffers) { + assertEquals(0, buf.getReferenceCount(), "Buffer should be released after parent close"); + } + } + + @DisplayName("should not leak buffers when exception occurs during write operations") + @Test + void shouldNotLeakBuffersOnWriteException() { + List allocatedBuffers = new ArrayList<>(); + AtomicInteger allocationCount = new AtomicInteger(); + + // Create a provider that fails after allocating some buffers + BufferProvider failingProvider = size -> { + if (allocationCount.incrementAndGet() > 3) { + throw new RuntimeException("Simulated allocation failure"); + } + ByteBuf buf = new SimpleBufferProvider().getBuffer(size); + allocatedBuffers.add(buf); + return buf; + }; + + ByteBufferBsonOutput output = new ByteBufferBsonOutput(failingProvider); + try { + // Write enough to trigger multiple allocations and eventually fail + for (int i = 0; i < 100; i++) { + output.writeBytes(new byte[INITIAL_BUFFER_SIZE]); + } + Assertions.fail("Should have thrown exception"); + } catch (RuntimeException e) { + assertEquals("Simulated allocation failure", e.getMessage()); + } finally { + output.close(); + } + + // Verify all allocated buffers are released + for (ByteBuf buf : allocatedBuffers) { + assertEquals(0, buf.getReferenceCount(), "Buffer should be released after close"); + } + } + + @DisplayName("multiple calls to close should be idempotent") + @ParameterizedTest(name = "multiple calls to close should be idempotent. BufferProvider={0}") + @MethodSource("com.mongodb.internal.connection.ByteBufferBsonOutputTest#bufferProviders") + void closeShouldBeIdempotent(final BufferProvider bufferProvider) { + List allocatedBuffers = new ArrayList<>(); + BufferProvider trackingProvider = size -> { + ByteBuf buf = bufferProvider.getBuffer(size); + allocatedBuffers.add(buf); + return buf; + }; + + ByteBufferBsonOutput output = new ByteBufferBsonOutput(trackingProvider); + output.writeBytes(new byte[INITIAL_BUFFER_SIZE]); + + output.close(); + // Should not throw + output.close(); + output.close(); + + // Buffers should still be released + for (ByteBuf buf : allocatedBuffers) { + assertEquals(0, buf.getReferenceCount(), "Buffer should be released"); + } + } + + @DisplayName("branch close should be idempotent") + @ParameterizedTest(name = "branch close should be idempotent. 
BufferProvider={0}") + @MethodSource("com.mongodb.internal.connection.ByteBufferBsonOutputTest#bufferProviders") + void branchCloseShouldBeIdempotent(final BufferProvider bufferProvider) { + try (ByteBufferBsonOutput parent = new ByteBufferBsonOutput(bufferProvider)) { + ByteBufferBsonOutput.Branch branch = parent.branch(); + branch.writeBytes(new byte[100]); + + branch.close(); + // Should not throw + branch.close(); + branch.close(); + + // Parent should have the data + assertEquals(100, parent.getPosition()); + } + } + + @DisplayName("getByteBuffers called multiple times returns independent duplicates") + @ParameterizedTest(name = "getByteBuffers called multiple times returns independent duplicates. BufferProvider={0}") + @MethodSource("com.mongodb.internal.connection.ByteBufferBsonOutputTest#bufferProviders") + void getByteBuffersReturnsIndependentDuplicates(final BufferProvider bufferProvider) { + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(bufferProvider)) { + output.writeBytes(new byte[100]); + + List buffers1 = output.getByteBuffers(); + List buffers2 = output.getByteBuffers(); + + // Both should be valid + assertFalse(buffers1.isEmpty()); + assertFalse(buffers2.isEmpty()); + + // Release first set + buffers1.forEach(ByteBuf::release); + + // Second set should still be valid + for (ByteBuf buf : buffers2) { + assertTrue(buf.getReferenceCount() > 0, + "Second set of buffers should still be valid"); + } + + // Release second set + buffers2.forEach(ByteBuf::release); + } + } + + @DisplayName("unreleased getByteBuffers result should not prevent output close") + @ParameterizedTest(name = "unreleased getByteBuffers result should not prevent output close. BufferProvider={0}") + @MethodSource("com.mongodb.internal.connection.ByteBufferBsonOutputTest#bufferProviders") + void unreleasedGetByteBuffersShouldNotPreventClose(final BufferProvider bufferProvider) { + List leakedBuffers; + try (ByteBufferBsonOutput output = new ByteBufferBsonOutput(bufferProvider)) { + output.writeBytes(new byte[100]); + leakedBuffers = output.getByteBuffers(); + // Intentionally NOT releasing leakedBuffers to simulate a leak + } + + // The output closed successfully, but the leaked buffers still have references + // This simulates what happens when callers forget to release + for (ByteBuf buf : leakedBuffers) { + // Note: These buffers are duplicates, so they have their own reference count + // The internal buffers were released, but these duplicates were not + assertTrue(buf.getReferenceCount() >= 0, + "Leaked duplicate buffers may still have references"); + } + + // Clean up the leaked buffers + leakedBuffers.forEach(ByteBuf::release); + } + } + private static byte[] getBytes(final OutputBuffer basicOutputBuffer) throws IOException { ByteArrayOutputStream baos = new ByteArrayOutputStream(basicOutputBuffer.getSize()); basicOutputBuffer.pipe(baos); diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy index 77bdd5e2045..a306d1e29a9 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageSpecification.groovy @@ -159,16 +159,18 @@ class CommandMessageSpecification extends Specification { def output = new ByteBufferBsonOutput(new SimpleBufferProvider()) message.encode(output, new OperationContext(IgnorableRequestContext.INSTANCE, 
NoOpSessionContext.INSTANCE, Stub(TimeoutContext), null)) + def commandMessageBsonDocument = message.getCommandDocument(output) when: - def commandDocument = message.getCommandDocument(output) - def expectedCommandDocument = new BsonDocument('insert', new BsonString('coll')).append('documents', new BsonArray([new BsonDocument('_id', new BsonInt32(1)), new BsonDocument('_id', new BsonInt32(2))])) expectedCommandDocument.append('$db', new BsonString(namespace.getDatabaseName())) + then: - commandDocument == expectedCommandDocument + commandMessageBsonDocument == expectedCommandDocument + cleanup: + commandMessageBsonDocument?.close() where: [maxWireVersion, originalCommandDocument, payload] << [ diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java index 091518c715c..0bad2109974 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/CommandMessageTest.java @@ -164,8 +164,10 @@ void getCommandDocumentFromClientBulkWrite() { new OperationContext( IgnorableRequestContext.INSTANCE, NoOpSessionContext.INSTANCE, new TimeoutContext(TimeoutSettings.DEFAULT), null)); - BsonDocument actualCommandDocument = commandMessage.getCommandDocument(output); - assertEquals(expectedCommandDocument, actualCommandDocument); + + try (ByteBufBsonDocument actualCommandDocument = commandMessage.getCommandDocument(output)) { + assertEquals(expectedCommandDocument, actualCommandDocument); + } } } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java index 30a70042578..141fb76d818 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/CompositeByteBufTest.java @@ -16,6 +16,9 @@ package com.mongodb.internal.connection; import com.mongodb.internal.connection.netty.NettyByteBuf; +import com.mongodb.lang.NonNull; +import io.netty.buffer.UnpooledByteBufAllocator; +import org.bson.BsonBinaryWriter; import org.bson.ByteBuf; import org.bson.ByteBufNIO; import org.junit.jupiter.api.DisplayName; @@ -27,6 +30,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.ArrayList; import java.util.List; import java.util.stream.Stream; @@ -39,25 +43,59 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; - +@DisplayName("CompositeByteBuf") final class CompositeByteBufTest { + // Construction tests + @Test + @DisplayName("Construction: should throw IllegalArgumentException when buffers is null") @SuppressWarnings("ConstantConditions") - void shouldThrowIfBuffersIsNull() { + void constructorShouldThrowIfBuffersIsNull() { assertThrows(IllegalArgumentException.class, () -> new CompositeByteBuf((List) null)); } @Test - void shouldThrowIfBuffersIsEmpty() { + @DisplayName("Construction: should throw IllegalArgumentException when buffers is empty") + void 
constructorShouldThrowIfBuffersIsEmpty() { assertThrows(IllegalArgumentException.class, () -> new CompositeByteBuf(emptyList())); } - @DisplayName("referenceCount should be maintained") - @ParameterizedTest + @Test + @DisplayName("Construction: should calculate capacity as sum of all buffer limits") + void constructorShouldCalculateCapacityAsSumOfBufferLimits() { + assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).capacity()); + assertEquals(6, new CompositeByteBuf(asList( + new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), + new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2})) + )).capacity()); + } + + @Test + @DisplayName("Construction: should initialize position to zero") + void constructorShouldInitializePositionToZero() { + assertEquals(0, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).position()); + } + + @Test + @DisplayName("Construction: should initialize limit as sum of all buffer limits") + void constructorShouldInitializeLimitAsSumOfBufferLimits() { + assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).limit()); + assertEquals(6, new CompositeByteBuf(asList( + new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), + new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2})) + )).limit()); + } + + // Reference counting tests + + @DisplayName("Reference counting: should maintain reference count correctly") + @ParameterizedTest(name = "with {0}") @MethodSource("getBuffers") void referenceCountShouldBeMaintained(final List buffers) { CompositeByteBuf buf = new CompositeByteBuf(buffers); @@ -76,108 +114,198 @@ void referenceCountShouldBeMaintained(final List buffers) { assertThrows(IllegalStateException.class, buf::retain); } - private static Stream getBuffers() { - return Stream.of( - Arguments.of(Named.of("ByteBufNIO", - asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), - new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))))), - Arguments.of(Named.of("NettyByteBuf", - asList(new NettyByteBuf(copiedBuffer(new byte[]{1, 2, 3, 4})), - new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4}))))), - Arguments.of(Named.of("Mixed NIO and NettyByteBuf", - asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), - new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4}))))) - ); + @ParameterizedTest(name = "with {0}") + @DisplayName("Reference counting: should release underlying buffers when reference count reaches zero") + @MethodSource("bufferProviders") + void releaseShouldReleaseUnderlyingBuffers(final String description, final TrackingBufferProvider bufferProvider) { + List buffers = asList(bufferProvider.getBuffer(1), bufferProvider.getBuffer(1)); + CompositeByteBuf compositeByteBuf = new CompositeByteBuf(buffers); + + assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() > 0)); + bufferProvider.assertAllAvailable(); + + compositeByteBuf.release(); + buffers.forEach(ByteBuf::release); + + assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() == 0)); + bufferProvider.assertAllUnavailable(); + } + + @ParameterizedTest(name = "with {0}") + @DisplayName("Reference counting: duplicate should have independent reference count from original") + @MethodSource("bufferProviders") + void duplicateShouldHaveIndependentReferenceCount(final String description, final TrackingBufferProvider bufferProvider) { + List buffers = asList(bufferProvider.getBuffer(1), 
bufferProvider.getBuffer(1)); + CompositeByteBuf compositeBuffer = new CompositeByteBuf(buffers); + assertEquals(1, compositeBuffer.getReferenceCount()); + + ByteBuf compositeBufferDuplicate = compositeBuffer.duplicate(); + assertEquals(1, compositeBufferDuplicate.getReferenceCount()); + assertEquals(1, compositeBuffer.getReferenceCount()); + + compositeBuffer.release(); + assertEquals(0, compositeBuffer.getReferenceCount()); + assertEquals(1, compositeBufferDuplicate.getReferenceCount()); + + compositeBufferDuplicate.release(); + assertEquals(0, compositeBuffer.getReferenceCount()); + assertEquals(0, compositeBufferDuplicate.getReferenceCount()); + + bufferProvider.assertAllAvailable(); + buffers.forEach(ByteBuf::release); + + assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() == 0)); + bufferProvider.assertAllUnavailable(); + } + + @ParameterizedTest(name = "with {0}") + @DisplayName("Reference counting: should work correctly with BsonBinaryWriter") + @MethodSource("bufferProviders") + void shouldWorkCorrectlyWithBsonBinaryWriter(final String description, final TrackingBufferProvider bufferProvider) { + List buffers; + + try (ByteBufferBsonOutput bufferBsonOutput = new ByteBufferBsonOutput(bufferProvider)) { + try (BsonBinaryWriter bsonBinaryWriter = new BsonBinaryWriter(bufferBsonOutput)) { + bsonBinaryWriter.writeStartDocument(); + bsonBinaryWriter.writeName("k"); + bsonBinaryWriter.writeInt32(42); + bsonBinaryWriter.writeEndDocument(); + bsonBinaryWriter.flush(); + } + buffers = bufferBsonOutput.getByteBuffers(); + assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() > 0)); + + CompositeByteBuf compositeBuffer = new CompositeByteBuf(buffers); + assertEquals(1, compositeBuffer.getReferenceCount()); + + ByteBuf compositeBufferDuplicate = compositeBuffer.duplicate(); + assertEquals(1, compositeBufferDuplicate.getReferenceCount()); + assertEquals(1, compositeBuffer.getReferenceCount()); + + compositeBuffer.release(); + assertEquals(0, compositeBuffer.getReferenceCount()); + assertEquals(1, compositeBufferDuplicate.getReferenceCount()); + + compositeBufferDuplicate.release(); + assertEquals(0, compositeBuffer.getReferenceCount()); + assertEquals(0, compositeBufferDuplicate.getReferenceCount()); + + bufferProvider.assertAllAvailable(); + buffers.forEach(ByteBuf::release); + } + + assertTrue(buffers.stream().allMatch(buffer -> buffer.getReferenceCount() == 0)); + bufferProvider.assertAllUnavailable(); } @Test + @DisplayName("Reference counting: should throw IllegalStateException when accessing released buffer") + void shouldThrowIllegalStateExceptionIfBufferIsClosed() { + CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); + buf.release(); + + assertThrows(IllegalStateException.class, buf::get); + } + + // Byte order tests + + @Test + @DisplayName("Byte order: should throw UnsupportedOperationException for BIG_ENDIAN byte order") void orderShouldThrowIfNotLittleEndian() { CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); assertThrows(UnsupportedOperationException.class, () -> buf.order(ByteOrder.BIG_ENDIAN)); } @Test + @DisplayName("Byte order: should accept LITTLE_ENDIAN byte order") void orderShouldReturnNormallyIfLittleEndian() { CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); assertDoesNotThrow(() -> buf.order(ByteOrder.LITTLE_ENDIAN)); } + // 
 
+    // Position tests
+
     @Test
-    void limitShouldBeSumOfLimitsOfBuffers() {
-        assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).limit());
+    @DisplayName("Position: should set position when within valid range")
+    void positionShouldBeSetIfInRange() {
+        CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
 
-        assertEquals(6, new CompositeByteBuf(asList(
-                new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})),
-                new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2}))
-        )).limit());
+        for (int i = 0; i <= 3; i++) {
+            buf.position(i);
+            assertEquals(i, buf.position());
+        }
     }
 
     @Test
-    void capacityShouldBeTheInitialLimit() {
-        assertEquals(4, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).capacity());
-        assertEquals(6, new CompositeByteBuf(asList(
-                new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})),
-                new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2}))
-        )).capacity());
+    @DisplayName("Position: should throw IndexOutOfBoundsException when position is negative")
+    void positionShouldThrowForNegativeValue() {
+        CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
+        assertThrows(IndexOutOfBoundsException.class, () -> buf.position(-1));
     }
 
     @Test
-    void positionShouldBeZero() {
-        assertEquals(0, new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))).position());
+    @DisplayName("Position: should throw IndexOutOfBoundsException when position exceeds capacity")
+    void positionShouldThrowWhenExceedsCapacity() {
+        CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
+        assertThrows(IndexOutOfBoundsException.class, () -> buf.position(4));
     }
 
     @Test
-    void positionShouldBeSetIfInRange() {
+    @DisplayName("Position: should throw IndexOutOfBoundsException when position exceeds limit")
+    void positionShouldThrowWhenExceedsLimit() {
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
-        buf.position(0);
-        assertEquals(0, buf.position());
-
-        buf.position(1);
-        assertEquals(1, buf.position());
-
-        buf.position(2);
-        assertEquals(2, buf.position());
-
-        buf.position(3);
-        assertEquals(3, buf.position());
+        buf.limit(2);
+        assertThrows(IndexOutOfBoundsException.class, () -> buf.position(3));
     }
 
     @Test
-    void positionShouldThrowIfOutOfRange() {
-        CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
+    @DisplayName("Position: should update remaining bytes as position changes during reads")
+    void positionRemainingAndHasRemainingShouldUpdateAsBytesAreRead() {
+        CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
 
-        assertThrows(IndexOutOfBoundsException.class, () -> buf.position(-1));
-        assertThrows(IndexOutOfBoundsException.class, () -> buf.position(4));
+        for (int i = 0; i < 4; i++) {
+            assertEquals(i, buf.position());
+            assertEquals(4 - i, buf.remaining());
+            assertTrue(buf.hasRemaining());
+            buf.get();
+        }
 
-        buf.limit(2);
-        assertThrows(IndexOutOfBoundsException.class, () -> buf.position(3));
+        assertEquals(4, buf.position());
+        assertEquals(0, buf.remaining());
+        assertFalse(buf.hasRemaining());
     }
 
+    // Limit tests
+
     @Test
+    @DisplayName("Limit: should set limit when within valid range")
     void limitShouldBeSetIfInRange() {
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
-        buf.limit(0);
-        assertEquals(0, buf.limit());
-
-        buf.limit(1);
-        assertEquals(1, buf.limit());
-
-        buf.limit(2);
-        assertEquals(2, buf.limit());
-
-        buf.limit(3);
-        assertEquals(3, buf.limit());
+        for (int i = 0; i <= 3; i++) {
+            buf.limit(i);
+            assertEquals(i, buf.limit());
+        }
     }
 
     @Test
-    void limitShouldThrowIfOutOfRange() {
+    @DisplayName("Limit: should throw IndexOutOfBoundsException when limit is negative")
+    void limitShouldThrowForNegativeValue() {
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
-
         assertThrows(IndexOutOfBoundsException.class, () -> buf.limit(-1));
+    }
+
+    @Test
+    @DisplayName("Limit: should throw IndexOutOfBoundsException when limit exceeds capacity")
+    void limitShouldThrowWhenExceedsCapacity() {
+        CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
         assertThrows(IndexOutOfBoundsException.class, () -> buf.limit(4));
     }
 
+    // Clear tests
+
     @Test
+    @DisplayName("Clear: should reset position to zero and limit to capacity")
     void clearShouldResetPositionAndLimit() {
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3}))));
         buf.limit(2);
@@ -188,7 +316,10 @@ void clearShouldResetPositionAndLimit() {
         assertEquals(3, buf.limit());
     }
 
+    // Duplicate tests
+
     @Test
+    @DisplayName("Duplicate: should copy position and limit to duplicate")
     void duplicateShouldCopyAllProperties() {
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 1, 2, 3, 4, 1, 2}))));
         buf.limit(6);
@@ -203,35 +334,32 @@ void duplicateShouldCopyAllProperties() {
         assertEquals(2, buf.position());
     }
 
+    // Get byte tests
+
     @Test
-    void positionRemainingAndHasRemainingShouldUpdateAsBytesAreRead() {
+    @DisplayName("Get byte: relative get should read byte and move position")
+    void relativeGetShouldReadByteAndMovePosition() {
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
-        assertEquals(0, buf.position());
-        assertEquals(4, buf.remaining());
-        assertTrue(buf.hasRemaining());
-        buf.get();
+        assertEquals(1, buf.get());
         assertEquals(1, buf.position());
-        assertEquals(3, buf.remaining());
-        assertTrue(buf.hasRemaining());
-
-        buf.get();
+        assertEquals(2, buf.get());
         assertEquals(2, buf.position());
-        assertEquals(2, buf.remaining());
-        assertTrue(buf.hasRemaining());
+    }
 
-        buf.get();
-        assertEquals(3, buf.position());
-        assertEquals(1, buf.remaining());
-        assertTrue(buf.hasRemaining());
+    @Test
+    @DisplayName("Get byte: should throw IndexOutOfBoundsException when reading past limit")
+    void getShouldThrowWhenReadingPastLimit() {
+        CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
+        buf.position(4);
 
-        buf.get();
-        assertEquals(4, buf.position());
-        assertEquals(0, buf.remaining());
-        assertFalse(buf.hasRemaining());
+        assertThrows(IndexOutOfBoundsException.class, buf::get);
     }
 
+    // Get int tests
+
     @Test
+    @DisplayName("Get int: absolute getInt should read little-endian integer and preserve position")
     void absoluteGetIntShouldReadLittleEndianIntegerAndPreservePosition() {
         ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}));
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -243,6 +371,7 @@ void absoluteGetIntShouldReadLittleEndianIntegerAndPreservePosition() {
     }
 
     @Test
+    @DisplayName("Get int: absolute getInt should read correctly when integer is split across buffers")
     void absoluteGetIntShouldReadLittleEndianIntegerWhenIntegerIsSplitAcrossBuffers() {
         ByteBufNIO byteBufferOne = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2}));
         ByteBuf byteBufferTwo = new NettyByteBuf(wrappedBuffer(new byte[]{3, 4})).flip();
@@ -256,16 +385,30 @@ void absoluteGetIntShouldReadLittleEndianIntegerWhenIntegerIsSplitAcrossBuffers(
     }
 
     @Test
+    @DisplayName("Get int: relative getInt should read little-endian integer and move position")
     void relativeGetIntShouldReadLittleEndianIntegerAndMovePosition() {
         ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}));
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
         int i = buf.getInt();
+
+        assertEquals(67305985, i);
         assertEquals(4, buf.position());
         assertEquals(0, byteBuffer.position());
     }
 
     @Test
+    @DisplayName("Get int: should throw IndexOutOfBoundsException when not enough bytes for int")
+    void getIntShouldThrowWhenNotEnoughBytesForInt() {
+        CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
+        buf.position(1);
+
+        assertThrows(IndexOutOfBoundsException.class, buf::getInt);
+    }
+
+    // Get long tests
+
+    @Test
+    @DisplayName("Get long: absolute getLong should read little-endian long and preserve position")
     void absoluteGetLongShouldReadLittleEndianLongAndPreservePosition() {
         ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}));
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -277,7 +420,8 @@ void absoluteGetLongShouldReadLittleEndianLongAndPreservePosition() {
     }
 
     @Test
-    void absoluteGetLongShouldReadLittleEndianLongWhenDoubleIsSplitAcrossBuffers() {
+    @DisplayName("Get long: absolute getLong should read correctly when long is split across multiple buffers")
+    void absoluteGetLongShouldReadLittleEndianLongWhenSplitAcrossBuffers() {
         ByteBuf byteBufferOne = new NettyByteBuf(wrappedBuffer(new byte[]{1, 2})).flip();
         ByteBuf byteBufferTwo = new ByteBufNIO(ByteBuffer.wrap(new byte[]{3, 4}));
         ByteBuf byteBufferThree = new NettyByteBuf(wrappedBuffer(new byte[]{5, 6})).flip();
@@ -287,11 +431,10 @@ void absoluteGetLongShouldReadLittleEndianLongWhenDoubleIsSplitAcrossBuffers() {
         assertEquals(578437695752307201L, l);
         assertEquals(0, buf.position());
-        assertEquals(0, byteBufferOne.position());
-        assertEquals(0, byteBufferTwo.position());
     }
 
     @Test
+    @DisplayName("Get long: relative getLong should read little-endian long and move position")
     void relativeGetLongShouldReadLittleEndianLongAndMovePosition() {
         ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}));
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -302,7 +445,10 @@ void relativeGetLongShouldReadLittleEndianLongAndMovePosition() {
         assertEquals(0, byteBuffer.position());
     }
 
+    // Get double tests
+
     @Test
+    @DisplayName("Get double: absolute getDouble should read little-endian double and preserve position")
     void absoluteGetDoubleShouldReadLittleEndianDoubleAndPreservePosition() {
         ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}));
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -314,6 +460,7 @@ void absoluteGetDoubleShouldReadLittleEndianDoubleAndPreservePosition() {
     }
 
     @Test
+    @DisplayName("Get double: relative getDouble should read little-endian double and move position")
     void relativeGetDoubleShouldReadLittleEndianDoubleAndMovePosition() {
         ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}));
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -324,7 +471,10 @@ void relativeGetDoubleShouldReadLittleEndianDoubleAndMovePosition() {
         assertEquals(0, byteBuffer.position());
     }
 
+    // Bulk get tests
+
     @Test
+    @DisplayName("Bulk get: absolute bulk get should read bytes and preserve position")
     void absoluteBulkGetShouldReadBytesAndPreservePosition() {
         ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}));
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -337,6 +487,7 @@ void absoluteBulkGetShouldReadBytesAndPreservePosition() {
     }
 
     @Test
+    @DisplayName("Bulk get: absolute bulk get should read bytes split across multiple buffers")
     void absoluteBulkGetShouldReadBytesWhenSplitAcrossBuffers() {
         ByteBufNIO byteBufferOne = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1}));
         ByteBufNIO byteBufferTwo = new ByteBufNIO(ByteBuffer.wrap(new byte[]{2, 3}));
@@ -354,6 +505,7 @@ void absoluteBulkGetShouldReadBytesWhenSplitAcrossBuffers() {
     }
 
     @Test
+    @DisplayName("Bulk get: relative bulk get should read bytes and move position")
     void relativeBulkGetShouldReadBytesAndMovePosition() {
         ByteBufNIO byteBuffer = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}));
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(byteBuffer));
@@ -372,6 +524,16 @@ void relativeBulkGetShouldReadBytesAndMovePosition() {
     }
 
     @Test
+    @DisplayName("Bulk get: should throw IndexOutOfBoundsException when bulk get exceeds remaining")
+    void bulkGetShouldThrowWhenBulkGetExceedsRemaining() {
+        CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
+        assertThrows(IndexOutOfBoundsException.class, () -> buf.get(new byte[2], 1, 5));
+    }
+
+    // asNIO tests
+
+    @Test
+    @DisplayName("asNIO: should get as NIO ByteBuffer with correct position and limit")
     void shouldGetAsNIOByteBuffer() {
         CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}))));
         buf.position(1).limit(5);
@@ -386,6 +548,7 @@ void shouldGetAsNIOByteBuffer() {
     }
 
     @Test
+    @DisplayName("asNIO: should consolidate multiple buffers into single NIO ByteBuffer")
     void shouldGetAsNIOByteBufferWithMultipleBuffers() {
         CompositeByteBuf buf = new CompositeByteBuf(asList(
                 new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2})),
@@ -403,25 +566,73 @@ void shouldGetAsNIOByteBufferWithMultipleBuffers() {
         assertArrayEquals(new byte[]{2, 3, 4, 5, 6}, bytes);
     }
 
-    @Test
-    void shouldThrowIndexOutOfBoundsExceptionIfReadingOutOfBounds() {
-        CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))));
-        buf.position(4);
+    // Test data providers
 
-        assertThrows(IndexOutOfBoundsException.class, buf::get);
-        buf.position(1);
-
-        assertThrows(IndexOutOfBoundsException.class, buf::getInt);
-        buf.position(0);
+    static Stream<Arguments> getBuffers() {
+        return Stream.of(
+                Arguments.of(Named.of("ByteBufNIO",
+                        asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})),
+                                new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4}))))),
+                Arguments.of(Named.of("NettyByteBuf",
+                        asList(new NettyByteBuf(copiedBuffer(new byte[]{1, 2, 3, 4})),
+                                new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4}))))),
+                Arguments.of(Named.of("Mixed NIO and NettyByteBuf",
NettyByteBuf", + asList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})), + new NettyByteBuf(wrappedBuffer(new byte[]{1, 2, 3, 4}))))) + ); + } - assertThrows(IndexOutOfBoundsException.class, () -> buf.get(new byte[2], 1, 2)); + static Stream bufferProviders() { + TrackingBufferProvider nioBufferProvider = new TrackingBufferProvider(size -> new ByteBufNIO(ByteBuffer.allocate(size))); + PowerOfTwoBufferPool bufferPool = new PowerOfTwoBufferPool(1); + bufferPool.disablePruning(); + TrackingBufferProvider pooledNioBufferProvider = new TrackingBufferProvider(bufferPool); + TrackingBufferProvider nettyBufferProvider = new TrackingBufferProvider(size -> new NettyByteBuf(UnpooledByteBufAllocator.DEFAULT.buffer(size, size))); + return Stream.of( + Arguments.of("NIO", nioBufferProvider), + Arguments.of("pooled NIO", pooledNioBufferProvider), + Arguments.of("Netty", nettyBufferProvider)); } - @Test - void shouldThrowIllegalStateExceptionIfBufferIsClosed() { - CompositeByteBuf buf = new CompositeByteBuf(singletonList(new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 2, 3, 4})))); - buf.release(); + private static final class TrackingBufferProvider implements BufferProvider { + private final BufferProvider decorated; + private final List tracked; + + TrackingBufferProvider(final BufferProvider decorated) { + this.decorated = decorated; + tracked = new ArrayList<>(); + } + + @NonNull + @Override + public ByteBuf getBuffer(final int size) { + ByteBuf result = decorated.getBuffer(size); + tracked.add(result); + return result; + } + + void assertAllAvailable() { + for (ByteBuf buffer : tracked) { + assertTrue(buffer.getReferenceCount() > 0); + if (buffer instanceof ByteBufNIO) { + assertNotNull(buffer.asNIO()); + } else if (buffer instanceof NettyByteBuf) { + assertTrue(((NettyByteBuf) buffer).asByteBuf().refCnt() > 0); + } + } + } + + void assertAllUnavailable() { + for (ByteBuf buffer : tracked) { + assertEquals(0, buffer.getReferenceCount()); + if (buffer instanceof ByteBufNIO) { + assertNull(buffer.asNIO()); + } + if (buffer instanceof NettyByteBuf) { + assertEquals(0, ((NettyByteBuf) buffer).asByteBuf().refCnt()); + } + } + } - assertThrows(IllegalStateException.class, buf::get); } } diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java index 3aff244ea1e..0176b8e9ad3 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java @@ -56,8 +56,10 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -254,6 +256,55 @@ public void serverHeartbeatFailed(final ServerHeartbeatFailedEvent event) { assertEquals(expectedEvents, events); } + @Test + void closeDuringConnectionShouldNotLeakBuffers() throws Exception { + CountDownLatch connectionStarted = new CountDownLatch(1); + CountDownLatch proceedWithOpen = new CountDownLatch(1); + + InternalConnection mockConnection = mock(InternalConnection.class); + doAnswer(invocation -> { + connectionStarted.countDown(); + assertTrue(proceedWithOpen.await(5, TimeUnit.SECONDS)); + return 
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java
index 3aff244ea1e..0176b8e9ad3 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java
+++ b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java
@@ -56,8 +56,10 @@
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.atLeast;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -254,6 +256,55 @@ public void serverHeartbeatFailed(final ServerHeartbeatFailedEvent event) {
         assertEquals(expectedEvents, events);
     }
 
+    @Test
+    void closeDuringConnectionShouldNotLeakBuffers() throws Exception {
+        CountDownLatch connectionStarted = new CountDownLatch(1);
+        CountDownLatch proceedWithOpen = new CountDownLatch(1);
+
+        InternalConnection mockConnection = mock(InternalConnection.class);
+        doAnswer(invocation -> {
+            connectionStarted.countDown();
+            assertTrue(proceedWithOpen.await(5, TimeUnit.SECONDS));
+            return null;
+        }).when(mockConnection).open(any());
+
+        when(mockConnection.getInitialServerDescription())
+                .thenReturn(createDefaultServerDescription());
+
+        InternalConnectionFactory factory = createConnectionFactory(mockConnection);
+
+        monitor = createAndStartMonitor(factory, mock(ServerMonitorListener.class));
+
+        // Wait for connection to start opening
+        assertTrue(connectionStarted.await(5, TimeUnit.SECONDS));
+
+        // Close monitor while connection is opening
+        monitor.close();
+
+        // Allow connection to complete
+        proceedWithOpen.countDown();
+
+        // Wait for the monitor thread to terminate so cleanup has completed before the test ends
+        monitor.getServerMonitor().join(5000);
+    }
+
+    @Test
+    void heartbeatWithNullConnectionDescriptionShouldNotCrash() throws Exception {
+        InternalConnection mockConnection = mock(InternalConnection.class);
+        when(mockConnection.getDescription()).thenReturn(null);
+        when(mockConnection.getInitialServerDescription())
+                .thenReturn(createDefaultServerDescription());
+        when(mockConnection.isClosed()).thenReturn(false);
+
+        InternalConnectionFactory factory = createConnectionFactory(mockConnection);
+        monitor = createAndStartMonitor(factory, mock(ServerMonitorListener.class));
+
+        // Wait a bit for the monitor to run
+        Thread.sleep(500);
+
+        // Monitor should handle null description gracefully
+        verify(mockConnection, atLeast(1)).open(any());
+    }
 
     private InternalConnectionFactory createConnectionFactory(final InternalConnection connection) {
         InternalConnectionFactory factory = mock(InternalConnectionFactory.class);
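The close-during-open test above makes an inherently racy scenario deterministic with a two-latch handshake. A generic, self-contained sketch of the pattern (names hypothetical; the real test coordinates the driver's monitor and a mocked connection the same way):

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class RaceHarnessSketch {
    public static void main(String[] args) throws InterruptedException {
        CountDownLatch entered = new CountDownLatch(1); // worker has started the operation
        CountDownLatch proceed = new CountDownLatch(1); // test allows it to finish

        Thread worker = new Thread(() -> {
            entered.countDown(); // signal: "inside the operation"
            try {
                proceed.await(5, TimeUnit.SECONDS); // hold the race window open
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            // finish the operation; anything acquired here must still be released
        });
        worker.start();

        entered.await(5, TimeUnit.SECONDS); // wait until the window is open
        // ... perform the racing action here, e.g. close() ...
        proceed.countDown(); // let the worker complete
        worker.join(TimeUnit.SECONDS.toMillis(5));
    }
}
```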
diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy
index e6f6afb02e0..8e7a7b9d78d 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy
+++ b/driver-core/src/test/unit/com/mongodb/internal/connection/LoggingCommandEventSenderSpecification.groovy
@@ -64,8 +64,10 @@ class LoggingCommandEventSenderSpecification extends Specification {
             isDebugEnabled() >> debugLoggingEnabled
         }
         def operationContext = OPERATION_CONTEXT
+        def commandMessageDocument = message.getCommandDocument(bsonOutput)
+
         def sender = new LoggingCommandEventSender([] as Set, [] as Set, connectionDescription, commandListener,
-                operationContext, message, message.getCommandDocument(bsonOutput),
+                operationContext, message, commandMessageDocument,
                 new StructuredLogger(logger), LoggerSettings.builder().build())

        when:
@@ -87,6 +89,9 @@ class LoggingCommandEventSenderSpecification extends Specification {
                         database, commandDocument.getFirstKey(), 1, failureException)
         ])
 
+        cleanup:
+        commandMessageDocument?.close()
+
         where:
         debugLoggingEnabled << [true, false]
     }
@@ -110,8 +115,10 @@ class LoggingCommandEventSenderSpecification extends Specification {
             isDebugEnabled() >> true
         }
         def operationContext = OPERATION_CONTEXT
+        def commandMessageDocument = message.getCommandDocument(bsonOutput)
+
         def sender = new LoggingCommandEventSender([] as Set, [] as Set, connectionDescription, commandListener,
-                operationContext, message, message.getCommandDocument(bsonOutput), new StructuredLogger(logger),
+                operationContext, message, commandMessageDocument, new StructuredLogger(logger),
                 LoggerSettings.builder().build())

        when:
        sender.sendStartedEvent()
@@ -146,6 +153,9 @@
                     "request ID is ${message.getId()} and the operation ID is ${operationContext.getId()}.")
         }, failureException)
 
+        cleanup:
+        commandMessageDocument?.close()
+
         where:
         commandListener << [null, Stub(CommandListener)]
     }
@@ -167,6 +177,7 @@ class LoggingCommandEventSenderSpecification extends Specification {
             isDebugEnabled() >> true
         }
         def operationContext = OPERATION_CONTEXT
+        def commandMessageDocument = message.getCommandDocument(bsonOutput)
         def sender = new LoggingCommandEventSender([] as Set, [] as Set, connectionDescription, null,
-                operationContext, message, message.getCommandDocument(bsonOutput), new StructuredLogger(logger),
+                operationContext, message, commandMessageDocument, new StructuredLogger(logger),
                 LoggerSettings.builder().build())
@@ -182,6 +193,9 @@
                     "request ID is ${message.getId()} and the operation ID is ${operationContext.getId()}. " +
                     "Command: {\"fake\": {\"\$binary\": {\"base64\": \"${'A' * 967} ..."
         }
+
+        cleanup:
+        commandMessageDocument?.close()
     }
 
     def 'should log redacted command with ellipses'() {
@@ -201,8 +215,9 @@
             isDebugEnabled() >> true
         }
         def operationContext = OPERATION_CONTEXT
+        def commandMessageDocument = message.getCommandDocument(bsonOutput)
         def sender = new LoggingCommandEventSender(['createUser'] as Set, [] as Set, connectionDescription, null,
-                operationContext, message, message.getCommandDocument(bsonOutput), new StructuredLogger(logger),
+                operationContext, message, commandMessageDocument, new StructuredLogger(logger),
                 LoggerSettings.builder().build())
 
         when:
@@ -215,5 +230,8 @@
                     "${connectionDescription.connectionId.serverValue} to 127.0.0.1:27017. The " +
                     "request ID is ${message.getId()} and the operation ID is ${operationContext.getId()}. Command: {}"
         }
+
+        cleanup:
+        commandMessageDocument?.close()
     }
 }
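Every hunk in this spec applies the same fix: the command document returned by `getCommandDocument(bsonOutput)` is a closeable resource (it is released in `cleanup:` above), so it must be captured in a local and that same instance closed afterwards. Creating it inline would leave no reference to close, and closing a second, freshly created document would not release the one actually handed to the sender. A generic sketch of the capture-then-close shape (types hypothetical):

```java
import java.util.function.Supplier;

final class CaptureThenCloseSketch {
    // Stand-in for a close-once resource such as a buffer-backed document.
    interface CloseableDocument extends AutoCloseable {
        @Override
        void close(); // unchecked, so it can be closed in a finally/cleanup block
    }

    static void use(CloseableDocument doc) { /* hand to a sender, log it, ... */ }

    static void correct(Supplier<CloseableDocument> factory) {
        CloseableDocument doc = factory.get(); // one instance, captured once
        try {
            use(doc);
        } finally {
            doc.close(); // closes the instance that was actually used
        }
    }
}
```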
diff --git a/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy
index 19bfa994200..6217c220e85 100644
--- a/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy
+++ b/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy
@@ -16,6 +16,7 @@
 package com.mongodb.internal.session
 
+import com.mongodb.MongoException
 import com.mongodb.ServerAddress
 import com.mongodb.connection.ClusterDescription
 import com.mongodb.connection.ClusterSettings
@@ -226,4 +227,34 @@ class ServerSessionPoolSpecification extends Specification {
                 { it instanceof BsonDocumentCodec }, _) >> new BsonDocument()
         1 * connection.release()
     }
+
+    def 'should handle MongoException during endSessions without leaking resources'() {
+        given:
+        def connection = Mock(Connection)
+        def server = Stub(Server) {
+            getConnection(_) >> connection
+        }
+        def cluster = Mock(Cluster) {
+            getCurrentDescription() >> connectedDescription
+        }
+        def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi())
+        def sessions = []
+        5.times { sessions.add(pool.get()) }
+
+        for (def cur : sessions) {
+            pool.release(cur)
+        }
+
+        when:
+        pool.close()
+
+        then:
+        1 * cluster.selectServer(_, _) >> new ServerTuple(server, connectedDescription.serverDescriptions[0])
+        1 * connection.command('admin',
+                new BsonDocument('endSessions', new BsonArray(sessions*.getIdentifier())),
+                { it instanceof NoOpFieldNameValidator }, primaryPreferred(),
+                { it instanceof BsonDocumentCodec }, _) >> { throw new MongoException("Simulated error") }
+        1 * connection.release()
+        // The test passes if no exception is propagated and the connection is still released
+    }
 }
diff --git a/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java b/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java
index dd761234df9..75a60ca382f 100644
--- a/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java
+++ b/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java
@@ -39,9 +39,9 @@ class DBDecoderAdapter implements Decoder<DBObject> {
 
     @Override
     public DBObject decode(final BsonReader reader, final DecoderContext decoderContext) {
-        ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(bufferProvider);
-        BsonBinaryWriter binaryWriter = new BsonBinaryWriter(bsonOutput);
-        try {
+
+        try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(bufferProvider);
+             BsonBinaryWriter binaryWriter = new BsonBinaryWriter(bsonOutput)) {
             binaryWriter.pipe(reader);
             BufferExposingByteArrayOutputStream byteArrayOutputStream =
                     new BufferExposingByteArrayOutputStream(binaryWriter.getBsonOutput().getSize());
@@ -50,9 +50,6 @@ public DBObject decode(final BsonReader reader, final DecoderContext decoderCont
         } catch (IOException e) {
             // impossible with a byte array output stream
             throw new MongoInternalException("An unlikely IOException thrown.", e);
-        } finally {
-            binaryWriter.close();
-            bsonOutput.close();
         }
     }
diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java
index 5cf0ea103bd..7a0b016cc98 100644
--- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java
+++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java
@@ -37,6 +37,8 @@
 import reactor.core.publisher.Mono;
 import reactor.core.publisher.MonoSink;
 
+import java.util.concurrent.atomic.AtomicBoolean;
+
 import static com.mongodb.MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL;
 import static com.mongodb.MongoException.UNKNOWN_TRANSACTION_COMMIT_RESULT_LABEL;
 import static com.mongodb.assertions.Assertions.assertNotNull;
@@ -52,7 +54,7 @@ final class ClientSessionPublisherImpl extends BaseClientSessionImpl implements
     private boolean messageSentInCurrentTransaction;
     private boolean commitInProgress;
     private TransactionOptions transactionOptions;
-
+    private final AtomicBoolean closed = new AtomicBoolean();
 
     ClientSessionPublisherImpl(final ServerSessionPool serverSessionPool, final MongoClientImpl mongoClient,
             final ClientSessionOptions options, final OperationExecutor executor) {
@@ -221,10 +223,18 @@ private void clearTransactionContextOnError(final MongoException e) {
 
     @Override
     public void close() {
-        if (transactionState == TransactionState.IN) {
-            Mono.from(abortTransaction()).doFinally(it -> super.close()).subscribe();
-        } else {
-            super.close();
+        if (closed.compareAndSet(false, true)) {
+            if (transactionState == TransactionState.IN) {
+                Mono.from(abortTransaction())
+                        .doFinally(it -> {
+                            clearTransactionContext();
+                            super.close();
+                        })
+                        .subscribe();
+            } else {
+                clearTransactionContext();
+                super.close();
+            }
         }
     }
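The `close()` rework above combines two fixes: the transaction context is always cleared, and an `AtomicBoolean` guard makes close idempotent. The guard is the standard compare-and-set idiom; even with concurrent callers, exactly one wins the CAS and runs the cleanup path. A minimal sketch:

```java
import java.util.concurrent.atomic.AtomicBoolean;

final class IdempotentCloseSketch implements AutoCloseable {
    private final AtomicBoolean closed = new AtomicBoolean();

    @Override
    public void close() {
        // false -> true succeeds for exactly one caller; all later or
        // concurrent calls observe true and return without re-running cleanup.
        if (closed.compareAndSet(false, true)) {
            releaseResources();
        }
    }

    private void releaseResources() {
        // runs at most once
    }
}
```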
diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java
index 2881b47e38e..05ca89dd048 100644
--- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java
+++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java
@@ -24,6 +24,7 @@
 import com.mongodb.MongoTimeoutException;
 import com.mongodb.connection.ClusterType;
 import com.mongodb.connection.ServerVersion;
+import com.mongodb.connection.TransportSettings;
 import com.mongodb.reactivestreams.client.internal.MongoClientImpl;
 import org.bson.Document;
 import org.bson.conversions.Bson;
@@ -33,6 +34,7 @@
 import java.util.List;
 
 import static com.mongodb.ClusterFixture.TIMEOUT_DURATION;
+import static com.mongodb.ClusterFixture.getOverriddenTransportSettings;
 import static com.mongodb.ClusterFixture.getServerApi;
 import static com.mongodb.internal.thread.InterruptionUtil.interruptAndCreateMongoInterruptedException;
 import static java.lang.Thread.sleep;
@@ -67,11 +69,18 @@ public static MongoClientSettings.Builder getMongoClientSettingsBuilder() {
     }
 
     public static MongoClientSettings.Builder getMongoClientSettingsBuilder(final ConnectionString connectionString) {
-        MongoClientSettings.Builder builder = MongoClientSettings.builder();
+        MongoClientSettings.Builder builder = MongoClientSettings.builder()
+                .applyConnectionString(connectionString);
+
+        TransportSettings overriddenTransportSettings = getOverriddenTransportSettings();
+        if (overriddenTransportSettings != null) {
+            builder.transportSettings(overriddenTransportSettings);
+        }
+
         if (getServerApi() != null) {
             builder.serverApi(getServerApi());
         }
-        return builder.applyConnectionString(connectionString);
+        return builder;
     }
 
     public static String getDefaultDatabaseName() {
@@ -164,6 +173,11 @@ public static synchronized ConnectionString getConnectionString() {
     public static MongoClientSettings.Builder getMongoClientBuilderFromConnectionString() {
         MongoClientSettings.Builder builder = MongoClientSettings.builder()
                 .applyConnectionString(getConnectionString());
+
+        TransportSettings overriddenTransportSettings = getOverriddenTransportSettings();
+        if (overriddenTransportSettings != null) {
+            builder.transportSettings(overriddenTransportSettings);
+        }
         if (getServerApi() != null) {
             builder.serverApi(getServerApi());
         }
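One detail worth noting in the fixture change: the connection string is now applied first and the environment-driven transport override afterwards. For a given setting, later builder calls overwrite earlier ones, so this ordering (presumably chosen so the CI matrix's async-transport override always takes effect) keeps the override from being clobbered. Condensed from the change above, using the fixture's own `getOverriddenTransportSettings()` helper:

```java
// Sketch of the override ordering used by the test fixture above.
MongoClientSettings.Builder builder = MongoClientSettings.builder()
        .applyConnectionString(connectionString); // baseline from the URI first

TransportSettings override = getOverriddenTransportSettings(); // null when no override is configured
if (override != null) {
    builder.transportSettings(override); // applied afterwards, so it wins
}
```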
diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/NettySettingsSmokeTestSpecification.groovy b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/NettySettingsSmokeTestSpecification.groovy
deleted file mode 100644
index 7e35e9a183a..00000000000
--- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/NettySettingsSmokeTestSpecification.groovy
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Copyright 2008-present MongoDB, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.mongodb.reactivestreams.client
-
-import com.mongodb.MongoClientSettings
-import com.mongodb.connection.TransportSettings
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.oio.OioSocketChannel
-import org.bson.Document
-import reactor.core.publisher.Mono
-
-import static Fixture.getMongoClientBuilderFromConnectionString
-import static com.mongodb.ClusterFixture.TIMEOUT_DURATION
-
-@SuppressWarnings('deprecation')
-class NettySettingsSmokeTestSpecification extends FunctionalSpecification {
-
-    private MongoClient mongoClient
-
-    def 'should allow a custom Event Loop Group and Socket Channel'() {
-        given:
-        def eventLoopGroup = new OioEventLoopGroup()
-        def nettySettings = TransportSettings.nettyBuilder()
-                .eventLoopGroup(eventLoopGroup)
-                .socketChannelClass(OioSocketChannel)
-                .build()
-        MongoClientSettings settings = getMongoClientBuilderFromConnectionString()
-                .transportSettings(nettySettings).build()
-        def document = new Document('a', 1)
-
-        when:
-        mongoClient = MongoClients.create(settings)
-        def collection = mongoClient.getDatabase(databaseName).getCollection(collectionName)
-
-
-        then:
-        Mono.from(collection.insertOne(document)).block(TIMEOUT_DURATION)
-
-        then: 'The count is one'
-        Mono.from(collection.countDocuments()).block(TIMEOUT_DURATION) == 1
-
-        cleanup:
-        mongoClient?.close()
-    }
-
-}
diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java
index 3682bd64ff0..764f3dbaa47 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java
@@ -93,7 +93,9 @@ public void shouldCreateServerSessionOnlyAfterConnectionCheckout() throws Interr
                         .addCommandListener(new CommandListener() {
                             @Override
                             public void commandStarted(final CommandStartedEvent event) {
-                                lsidSet.add(event.getCommand().getDocument("lsid"));
+                                if (event.getCommand().containsKey("lsid")) {
+                                    lsidSet.add(event.getCommand().getDocument("lsid").toBsonDocument());
+                                }
                             }
                         })
                         .build())) {
diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java
index cf003078f04..1fc2d18a1fb 100644
--- a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java
+++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java
@@ -311,6 +311,7 @@ public void cleanUp() {
         if (testDef != null) {
             postCleanUp(testDef);
         }
+        // Nudge GC so that any ByteBuf that escaped release is collected and reported by Netty's leak detector
+        System.gc();
     }
 
     protected void postCleanUp(final TestDef testDef) {
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 8a08c34f213..c75eb6d7000 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -18,7 +18,7 @@ aws-sdk-v2 = "2.30.31"
 graal-sdk = "24.0.0"
 jna = "5.11.0"
 jnr-unixsocket = "0.38.17"
-netty-bom = "4.1.87.Final"
+netty-bom = "4.2.9.Final"
 project-reactor-bom = "2022.0.0"
 reactive-streams = "1.0.4"
 snappy = "1.1.10.3"
diff --git a/mongodb-crypt/build.gradle.kts b/mongodb-crypt/build.gradle.kts
index 812753151d5..08a59aa091d 100644
--- a/mongodb-crypt/build.gradle.kts
+++ b/mongodb-crypt/build.gradle.kts
@@ -54,7 +54,7 @@ val jnaLibsPath: String = System.getProperty("jnaLibsPath", "${jnaResourcesDir}$
 val jnaResources: String = System.getProperty("jna.library.path", jnaLibsPath)
 
 // Download jnaLibs that match the git tag or revision to jnaResourcesBuildDir
-val downloadRevision = "1.15.1"
+val downloadRevision = "1.15.2"
 val binariesArchiveName = "libmongocrypt-java.tar.gz"
 
 /**