From 25e934ca6cd2510b759ed02d6cf74ade9bff4ea3 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 6 May 2026 15:38:38 +0100 Subject: [PATCH] Add $rerank aggregation stage support Adds builder support for the $rerank pipeline stage (MongoDB 8.3, Atlas only). API: - RerankQuery: query object with text shorthand or full Bson for future modalities - Aggregates.rerank(): 2 overloads (single path, multi path) - Scala wrappers in Aggregates.scala and a RerankQuery type alias in package.scala JAVA-6052 --- .../com/mongodb/client/model/Aggregates.java | 86 ++++++++++++++++++ .../com/mongodb/client/model/RerankQuery.java | 82 +++++++++++++++++ .../mongodb/client/model/AggregatesTest.java | 91 +++++++++++++++++-- .../org/mongodb/scala/model/Aggregates.scala | 45 +++++++++ .../org/mongodb/scala/model/package.scala | 2 + .../mongodb/scala/model/AggregatesSpec.scala | 45 +++++++++ 6 files changed, 345 insertions(+), 6 deletions(-) create mode 100644 driver-core/src/main/com/mongodb/client/model/RerankQuery.java diff --git a/driver-core/src/main/com/mongodb/client/model/Aggregates.java b/driver-core/src/main/com/mongodb/client/model/Aggregates.java index 6a5950ab560..ce55837b967 100644 --- a/driver-core/src/main/com/mongodb/client/model/Aggregates.java +++ b/driver-core/src/main/com/mongodb/client/model/Aggregates.java @@ -59,6 +59,7 @@ import static com.mongodb.internal.Iterables.concat; import static com.mongodb.internal.client.model.Util.sizeAtLeast; import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; /** * Builders for aggregation pipeline stages. @@ -1040,6 +1041,57 @@ public static Bson vectorSearch( return new VectorSearchBson(path, queryVector, index, limit, options); } + /** + * Creates a {@code $rerank} pipeline stage supported by MongoDB Atlas. + * You may use the {@code $meta: "score"} expression to extract the relevance score + * assigned to each reranked document. 
+ * + * @param query The query to rerank against, created via {@link RerankQuery#rerankQuery(String)}. + * @param path The document field to send to the reranker. + * @param numDocsToRerank The maximum number of documents to rerank (1-1000). + * @param model The reranking model name. Accepted values: + * {@code "rerank-2.5"}, {@code "rerank-2.5-lite"}, {@code "rerank-2"}, {@code "rerank-2-lite"}. + * @return The {@code $rerank} pipeline stage. + * @mongodb.server.release 8.3 + * @since 5.8 + */ + @Beta(Reason.SERVER) + public static Bson rerank( + final RerankQuery query, + final String path, + final int numDocsToRerank, + final String model) { + notNull("path", path); + return rerank(query, singletonList(path), numDocsToRerank, model); + } + + /** + * Creates a {@code $rerank} pipeline stage supported by MongoDB Atlas. + * You may use the {@code $meta: "score"} expression to extract the relevance score + * assigned to each reranked document. + * + * @param query The query to rerank against, created via {@link RerankQuery#rerankQuery(String)}. + * @param paths The document field(s) to send to the reranker. + * @param numDocsToRerank The maximum number of documents to rerank (1-1000). + * @param model The reranking model name. Accepted values: + * {@code "rerank-2.5"}, {@code "rerank-2.5-lite"}, {@code "rerank-2"}, {@code "rerank-2-lite"}. + * @return The {@code $rerank} pipeline stage. 
+ * @mongodb.server.release 8.3 + * @since 5.8 + */ + @Beta(Reason.SERVER) + public static Bson rerank( + final RerankQuery query, + final List paths, + final int numDocsToRerank, + final String model) { + notNull("query", query); + notNull("paths", paths); + isTrueArgument("paths must not be empty", !paths.isEmpty()); + notNull("model", model); + return new RerankBson(query, paths, numDocsToRerank, model); + } + /** * Creates an $unset pipeline stage that removes/excludes fields from documents * @@ -2290,4 +2342,38 @@ public String toString() { + '}'; } } + + private static class RerankBson implements Bson { + private final RerankQuery query; + private final List paths; + private final int numDocsToRerank; + private final String model; + + RerankBson(final RerankQuery query, final List paths, final int numDocsToRerank, + final String model) { + this.query = query; + this.paths = paths; + this.numDocsToRerank = numDocsToRerank; + this.model = model; + } + + @Override + public BsonDocument toBsonDocument(final Class documentClass, final CodecRegistry codecRegistry) { + Document specificationDoc = new Document("query", query) + .append("path", paths.size() == 1 ? paths.get(0) : paths) + .append("numDocsToRerank", numDocsToRerank) + .append("model", model); + return new Document("$rerank", specificationDoc).toBsonDocument(documentClass, codecRegistry); + } + + @Override + public String toString() { + return "Stage{name=$rerank" + + ", query=" + query + + ", paths=" + paths + + ", numDocsToRerank=" + numDocsToRerank + + ", model=" + model + + '}'; + } + } } diff --git a/driver-core/src/main/com/mongodb/client/model/RerankQuery.java b/driver-core/src/main/com/mongodb/client/model/RerankQuery.java new file mode 100644 index 00000000000..25027c07dc1 --- /dev/null +++ b/driver-core/src/main/com/mongodb/client/model/RerankQuery.java @@ -0,0 +1,82 @@ +/* + * Copyright 2008-present MongoDB, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.model; + +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.bson.annotations.Beta; +import org.bson.annotations.Reason; +import org.bson.codecs.configuration.CodecRegistry; +import org.bson.conversions.Bson; + +import static com.mongodb.assertions.Assertions.notNull; + +/** + * Represents a query for the {@code $rerank} aggregation pipeline stage. + * <p>
+ * Use {@link #rerankQuery(String)} for a simple text query, or + * {@link #rerankQuery(Bson)} to specify the full query document directly + * (e.g., for future modalities like imageURL or videoURL). + * + * @mongodb.server.release 8.3 + * @since 5.8 + */ +@Beta(Reason.SERVER) +public final class RerankQuery implements Bson { + private final Bson query; + + private RerankQuery(final Bson query) { + this.query = query; + } + + /** + * Creates a rerank query with the specified text. + * <p>
+ * This is a convenience for {@code rerankQuery(new Document("text", text))}. + * + * @param text the query text to rerank against. + * @return a new {@link RerankQuery} + */ + public static RerankQuery rerankQuery(final String text) { + notNull("text", text); + return new RerankQuery(new BsonDocument("text", new BsonString(text))); + } + + /** + * Creates a rerank query from a full query document. + * <p>
+ * Use this overload for future query modalities (e.g., imageURL, videoURL) + * or to pass additional fields alongside text. + * + * @param query the query document. + * @return a new {@link RerankQuery} + */ + public static RerankQuery rerankQuery(final Bson query) { + notNull("query", query); + return new RerankQuery(query); + } + + @Override + public BsonDocument toBsonDocument(final Class documentClass, final CodecRegistry codecRegistry) { + return query.toBsonDocument(documentClass, codecRegistry); + } + + @Override + public String toString() { + return "RerankQuery{" + query + '}'; + } +} diff --git a/driver-core/src/test/functional/com/mongodb/client/model/AggregatesTest.java b/driver-core/src/test/functional/com/mongodb/client/model/AggregatesTest.java index 7fd01712ea3..5cb70d4e2ef 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/AggregatesTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/AggregatesTest.java @@ -33,7 +33,6 @@ import org.junit.jupiter.params.provider.MethodSource; import java.math.RoundingMode; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Stream; @@ -43,8 +42,10 @@ import static com.mongodb.client.model.Accumulators.percentile; import static com.mongodb.client.model.Aggregates.geoNear; import static com.mongodb.client.model.Aggregates.group; +import static com.mongodb.client.model.Aggregates.rerank; import static com.mongodb.client.model.Aggregates.unset; import static com.mongodb.client.model.Aggregates.vectorSearch; +import static com.mongodb.client.model.RerankQuery.rerankQuery; import static com.mongodb.client.model.GeoNearOptions.geoNearOptions; import static com.mongodb.client.model.Sorts.ascending; import static com.mongodb.client.model.Windows.Bound.UNBOUNDED; @@ -260,7 +261,7 @@ public void testDocuments() { "{$documents: [{a: 1, b: {$add: [1, 1]}}, {a: 3, b: 4}]}", stage); - List pipeline = Arrays.asList(stage); + List pipeline = 
asList(stage); getCollectionHelper().aggregateDb(pipeline); assertEquals( @@ -268,9 +269,9 @@ public void testDocuments() { getCollectionHelper().aggregateDb(pipeline)); // accepts lists of Documents and BsonDocuments - List documents = Arrays.asList(BsonDocument.parse("{a: 1, b: 2}")); + List documents = asList(BsonDocument.parse("{a: 1, b: 2}")); assertPipeline("{$documents: [{a: 1, b: 2}]}", Aggregates.documents(documents)); - List bsonDocuments = Arrays.asList(BsonDocument.parse("{a: 1, b: 2}")); + List bsonDocuments = asList(BsonDocument.parse("{a: 1, b: 2}")); assertPipeline("{$documents: [{a: 1, b: 2}]}", Aggregates.documents(bsonDocuments)); } @@ -281,13 +282,13 @@ public void testDocumentsLookup() { getCollectionHelper().insertDocuments("[{_id: 1, a: 8}, {_id: 2, a: 9}]"); Bson documentsStage = Aggregates.documents(asList(Document.parse("{a: 5}"))); - Bson lookupStage = Aggregates.lookup(null, Arrays.asList(documentsStage), "added"); + Bson lookupStage = Aggregates.lookup(null, asList(documentsStage), "added"); assertPipeline( "{'$lookup': {'pipeline': [{'$documents': [{'a': 5}]}], 'as': 'added'}}", lookupStage); assertEquals( parseToList("[{_id:1, a:8, added: [{a: 5}]}, {_id:2, a:9, added: [{a: 5}]}]"), - getCollectionHelper().aggregate(Arrays.asList(lookupStage))); + getCollectionHelper().aggregate(asList(lookupStage))); } @Test @@ -374,4 +375,82 @@ public void testExactVectorSearchWithQueryObject() { exactVectorSearchOptions() )); } + + @Test + public void testRerankWithSinglePath() { + assertPipeline( + "{" + + " '$rerank': {" + + " 'query': {'text': 'machine learning tutorials'}," + + " 'path': 'content'," + + " 'numDocsToRerank': 25," + + " 'model': 'rerank-2.5'" + + " }" + + "}", + rerank( + rerankQuery("machine learning tutorials"), + "content", + 25, + "rerank-2.5" + )); + } + + @Test + public void testRerankWithMultiplePaths() { + assertPipeline( + "{" + + " '$rerank': {" + + " 'query': {'text': 'machine learning tutorials'}," + + " 'path': 
['content', 'title']," + + " 'numDocsToRerank': 50," + + " 'model': 'rerank-2.5-lite'" + + " }" + + "}", + rerank( + rerankQuery("machine learning tutorials"), + asList("content", "title"), + 50, + "rerank-2.5-lite" + )); + } + + @Test + public void testRerankWithBsonQuery() { + assertPipeline( + "{" + + " '$rerank': {" + + " 'query': {'text': 'machine learning tutorials', 'imageURL': 'https://example.com/img.png'}," + + " 'path': 'content'," + + " 'numDocsToRerank': 25," + + " 'model': 'rerank-2.5'" + + " }" + + "}", + rerank( + rerankQuery(new Document("text", "machine learning tutorials") + .append("imageURL", "https://example.com/img.png")), + "content", + 25, + "rerank-2.5" + )); + } + + @Test + public void testRerankWithMultiplePathsAndBsonQuery() { + assertPipeline( + "{" + + " '$rerank': {" + + " 'query': {'text': 'machine learning tutorials', 'imageURL': 'https://example.com/img.png'}," + + " 'path': ['content', 'title']," + + " 'numDocsToRerank': 100," + + " 'model': 'rerank-2'" + + " }" + + "}", + rerank( + rerankQuery(new Document("text", "machine learning tutorials") + .append("imageURL", "https://example.com/img.png")), + asList("content", "title"), + 100, + "rerank-2" + )); + } } diff --git a/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala b/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala index c7b8d120cf7..60125185a24 100644 --- a/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala +++ b/driver-scala/src/main/scala/org/mongodb/scala/model/Aggregates.scala @@ -22,6 +22,7 @@ import com.mongodb.client.model.search.FieldSearchPath import scala.collection.JavaConverters._ import com.mongodb.client.model.{ Aggregates => JAggregates } +import com.mongodb.client.model.RerankQuery import org.mongodb.scala.MongoNamespace import org.mongodb.scala.bson.conversions.Bson import org.mongodb.scala.model.densify.{ DensifyOptions, DensifyRange } @@ -746,6 +747,50 @@ object Aggregates { ): Bson = 
JAggregates.vectorSearch(path, queryVector.asJava, index, limit, options) + /** + * Creates a `\$rerank` pipeline stage supported by MongoDB Atlas. + * You may use the `\$meta: "score"` expression to extract the relevance score + * assigned to each reranked document. + * + * @param query The query to rerank against. + * @param path The document field to send to the reranker. + * @param numDocsToRerank The maximum number of documents to rerank (1-1000). + * @param model The reranking model name. + * @return The `\$rerank` pipeline stage. + * @note Requires MongoDB on Atlas 8.3 or greater + * @since 5.8 + */ + @Beta(Array(Reason.SERVER)) + def rerank( + query: RerankQuery, + path: String, + numDocsToRerank: Int, + model: String + ): Bson = + JAggregates.rerank(query, path, numDocsToRerank, model) + + /** + * Creates a `\$rerank` pipeline stage supported by MongoDB Atlas. + * You may use the `\$meta: "score"` expression to extract the relevance score + * assigned to each reranked document. + * + * @param query The query to rerank against. + * @param paths The document field(s) to send to the reranker. + * @param numDocsToRerank The maximum number of documents to rerank (1-1000). + * @param model The reranking model name. + * @return The `\$rerank` pipeline stage. 
+ * @note Requires MongoDB on Atlas 8.3 or greater + * @since 5.8 + */ + @Beta(Array(Reason.SERVER)) + def rerank( + query: RerankQuery, + paths: Seq[String], + numDocsToRerank: Int, + model: String + ): Bson = + JAggregates.rerank(query, paths.toList.asJava, numDocsToRerank, model) + /** * Creates an `\$unset` pipeline stage that removes/excludes fields from documents * diff --git a/driver-scala/src/main/scala/org/mongodb/scala/model/package.scala b/driver-scala/src/main/scala/org/mongodb/scala/model/package.scala index 0d23a38c2e8..2363d25244a 100644 --- a/driver-scala/src/main/scala/org/mongodb/scala/model/package.scala +++ b/driver-scala/src/main/scala/org/mongodb/scala/model/package.scala @@ -987,6 +987,8 @@ package object model { type GeoNearOptions = com.mongodb.client.model.GeoNearOptions + type RerankQuery = com.mongodb.client.model.RerankQuery + /** * @see `QuantileMethod.approximate()` */ diff --git a/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala b/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala index d5a38ad7bca..89a6e939bf2 100644 --- a/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala +++ b/driver-scala/src/test/scala/org/mongodb/scala/model/AggregatesSpec.scala @@ -37,6 +37,7 @@ import org.mongodb.scala.model.geojson.{ Point, Position } import org.mongodb.scala.model.search.SearchCount.total import org.mongodb.scala.model.search.SearchFacet.stringFacet import org.mongodb.scala.model.search.SearchHighlight.paths +import com.mongodb.client.model.RerankQuery import org.mongodb.scala.model.search.SearchCollector import org.mongodb.scala.model.search.SearchOperator.exists import org.mongodb.scala.model.search.SearchOptions.searchOptions @@ -816,6 +817,50 @@ class AggregatesSpec extends BaseSpec { ) } + it should "render $rerank with single path" in { + toBson( + Aggregates.rerank( + RerankQuery.rerankQuery("machine learning"), + "content", + 25, + "rerank-2.5" + ) + ) should equal( 
+ Document( + """{ + "$rerank": { + "query": {"text": "machine learning"}, + "path": "content", + "numDocsToRerank": 25, + "model": "rerank-2.5" + } + }""" + ) + ) + } + + it should "render $rerank with multiple paths" in { + toBson( + Aggregates.rerank( + RerankQuery.rerankQuery("machine learning"), + List("content", "title"), + 50, + "rerank-2.5-lite" + ) + ) should equal( + Document( + """{ + "$rerank": { + "query": {"text": "machine learning"}, + "path": ["content", "title"], + "numDocsToRerank": 50, + "model": "rerank-2.5-lite" + } + }""" + ) + ) + } + it should "render $unset" in { toBson( Aggregates.unset("title", "author.first")