Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions driver-core/src/main/com/mongodb/client/model/Aggregates.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import static com.mongodb.internal.Iterables.concat;
import static com.mongodb.internal.client.model.Util.sizeAtLeast;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;

/**
* Builders for aggregation pipeline stages.
Expand Down Expand Up @@ -1040,6 +1041,57 @@ public static Bson vectorSearch(
return new VectorSearchBson(path, queryVector, index, limit, options);
}

/**
* Creates a {@code $rerank} pipeline stage supported by MongoDB Atlas.
* You may use the {@code $meta: "score"} expression to extract the relevance score
* assigned to each reranked document.
*
* @param query The query to rerank against, created via {@link RerankQuery#rerankQuery(String)}.
* @param path The document field to send to the reranker.
* @param numDocsToRerank The maximum number of documents to rerank (1-1000).
* @param model The reranking model name. Accepted values:
* {@code "rerank-2.5"}, {@code "rerank-2.5-lite"}, {@code "rerank-2"}, {@code "rerank-2-lite"}.
* @return The {@code $rerank} pipeline stage.
* @mongodb.server.release 8.3
* @since 5.8
*/
@Beta(Reason.SERVER)
public static Bson rerank(
final RerankQuery query,
final String path,
final int numDocsToRerank,
final String model) {
notNull("path", path);
return rerank(query, singletonList(path), numDocsToRerank, model);
}

/**
* Creates a {@code $rerank} pipeline stage supported by MongoDB Atlas.
* You may use the {@code $meta: "score"} expression to extract the relevance score
* assigned to each reranked document.
*
* @param query The query to rerank against, created via {@link RerankQuery#rerankQuery(String)}.
* @param paths The document field(s) to send to the reranker.
* @param numDocsToRerank The maximum number of documents to rerank (1-1000).
* @param model The reranking model name. Accepted values:
* {@code "rerank-2.5"}, {@code "rerank-2.5-lite"}, {@code "rerank-2"}, {@code "rerank-2-lite"}.
* @return The {@code $rerank} pipeline stage.
* @mongodb.server.release 8.3
* @since 5.8
*/
@Beta(Reason.SERVER)
public static Bson rerank(
final RerankQuery query,
final List<String> paths,
final int numDocsToRerank,
final String model) {
notNull("query", query);
notNull("paths", paths);
isTrueArgument("paths must not be empty", !paths.isEmpty());
notNull("model", model);
return new RerankBson(query, paths, numDocsToRerank, model);
}

/**
* Creates an $unset pipeline stage that removes/excludes fields from documents
*
Expand Down Expand Up @@ -2290,4 +2342,38 @@ public String toString() {
+ '}';
}
}

private static class RerankBson implements Bson {
private final RerankQuery query;
private final List<String> paths;
private final int numDocsToRerank;
private final String model;

RerankBson(final RerankQuery query, final List<String> paths, final int numDocsToRerank,
final String model) {
this.query = query;
this.paths = paths;
this.numDocsToRerank = numDocsToRerank;
this.model = model;
}

@Override
public <TDocument> BsonDocument toBsonDocument(final Class<TDocument> documentClass, final CodecRegistry codecRegistry) {
Document specificationDoc = new Document("query", query)
.append("path", paths.size() == 1 ? paths.get(0) : paths)
.append("numDocsToRerank", numDocsToRerank)
.append("model", model);
return new Document("$rerank", specificationDoc).toBsonDocument(documentClass, codecRegistry);
}

@Override
public String toString() {
return "Stage{name=$rerank"
+ ", query=" + query
+ ", paths=" + paths
+ ", numDocsToRerank=" + numDocsToRerank
+ ", model=" + model
+ '}';
}
}
}
82 changes: 82 additions & 0 deletions driver-core/src/main/com/mongodb/client/model/RerankQuery.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright 2008-present MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mongodb.client.model;

import org.bson.BsonDocument;
import org.bson.BsonString;
import org.bson.annotations.Beta;
import org.bson.annotations.Reason;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.conversions.Bson;

import static com.mongodb.assertions.Assertions.notNull;

/**
* Represents a query for the {@code $rerank} aggregation pipeline stage.
* <p>
* Use {@link #rerankQuery(String)} for a simple text query, or
* {@link #rerankQuery(Bson)} to specify the full query document directly
* (e.g., for future modalities like imageURL or videoURL).
*
* @mongodb.server.release 8.3
* @since 5.8
*/
@Beta(Reason.SERVER)
public final class RerankQuery implements Bson {
private final Bson query;

private RerankQuery(final Bson query) {
this.query = query;
}

/**
* Creates a rerank query with the specified text.
* <p>
* This is a convenience for {@code rerankQuery(new Document("text", text))}.
*
* @param text the query text to rerank against.
* @return a new {@link RerankQuery}
*/
public static RerankQuery rerankQuery(final String text) {
notNull("text", text);
return new RerankQuery(new BsonDocument("text", new BsonString(text)));
}

/**
* Creates a rerank query from a full query document.
* <p>
* Use this overload for future query modalities (e.g., imageURL, videoURL)
* or to pass additional fields alongside text.
*
* @param query the query document.
* @return a new {@link RerankQuery}
*/
public static RerankQuery rerankQuery(final Bson query) {
notNull("query", query);
return new RerankQuery(query);
}

@Override
public <TDocument> BsonDocument toBsonDocument(final Class<TDocument> documentClass, final CodecRegistry codecRegistry) {
return query.toBsonDocument(documentClass, codecRegistry);
}

@Override
public String toString() {
return "RerankQuery{" + query + '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.junit.jupiter.params.provider.MethodSource;

import java.math.RoundingMode;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;
Expand All @@ -43,8 +42,10 @@
import static com.mongodb.client.model.Accumulators.percentile;
import static com.mongodb.client.model.Aggregates.geoNear;
import static com.mongodb.client.model.Aggregates.group;
import static com.mongodb.client.model.Aggregates.rerank;
import static com.mongodb.client.model.Aggregates.unset;
import static com.mongodb.client.model.Aggregates.vectorSearch;
import static com.mongodb.client.model.RerankQuery.rerankQuery;
import static com.mongodb.client.model.GeoNearOptions.geoNearOptions;
import static com.mongodb.client.model.Sorts.ascending;
import static com.mongodb.client.model.Windows.Bound.UNBOUNDED;
Expand Down Expand Up @@ -260,17 +261,17 @@ public void testDocuments() {
"{$documents: [{a: 1, b: {$add: [1, 1]}}, {a: 3, b: 4}]}",
stage);

List<Bson> pipeline = Arrays.asList(stage);
List<Bson> pipeline = asList(stage);
getCollectionHelper().aggregateDb(pipeline);

assertEquals(
parseToList("[{a: 1, b: 2}, {a: 3, b: 4}]"),
getCollectionHelper().aggregateDb(pipeline));

// accepts lists of Documents and BsonDocuments
List<BsonDocument> documents = Arrays.asList(BsonDocument.parse("{a: 1, b: 2}"));
List<BsonDocument> documents = asList(BsonDocument.parse("{a: 1, b: 2}"));
assertPipeline("{$documents: [{a: 1, b: 2}]}", Aggregates.documents(documents));
List<BsonDocument> bsonDocuments = Arrays.asList(BsonDocument.parse("{a: 1, b: 2}"));
List<BsonDocument> bsonDocuments = asList(BsonDocument.parse("{a: 1, b: 2}"));
assertPipeline("{$documents: [{a: 1, b: 2}]}", Aggregates.documents(bsonDocuments));
}

Expand All @@ -281,13 +282,13 @@ public void testDocumentsLookup() {
getCollectionHelper().insertDocuments("[{_id: 1, a: 8}, {_id: 2, a: 9}]");
Bson documentsStage = Aggregates.documents(asList(Document.parse("{a: 5}")));

Bson lookupStage = Aggregates.lookup(null, Arrays.asList(documentsStage), "added");
Bson lookupStage = Aggregates.lookup(null, asList(documentsStage), "added");
assertPipeline(
"{'$lookup': {'pipeline': [{'$documents': [{'a': 5}]}], 'as': 'added'}}",
lookupStage);
assertEquals(
parseToList("[{_id:1, a:8, added: [{a: 5}]}, {_id:2, a:9, added: [{a: 5}]}]"),
getCollectionHelper().aggregate(Arrays.asList(lookupStage)));
getCollectionHelper().aggregate(asList(lookupStage)));
}

@Test
Expand Down Expand Up @@ -374,4 +375,82 @@ public void testExactVectorSearchWithQueryObject() {
exactVectorSearchOptions()
));
}

@Test
public void testRerankWithSinglePath() {
assertPipeline(
"{"
+ " '$rerank': {"
+ " 'query': {'text': 'machine learning tutorials'},"
+ " 'path': 'content',"
+ " 'numDocsToRerank': 25,"
+ " 'model': 'rerank-2.5'"
+ " }"
+ "}",
rerank(
rerankQuery("machine learning tutorials"),
"content",
25,
"rerank-2.5"
));
}

@Test
public void testRerankWithMultiplePaths() {
assertPipeline(
"{"
+ " '$rerank': {"
+ " 'query': {'text': 'machine learning tutorials'},"
+ " 'path': ['content', 'title'],"
+ " 'numDocsToRerank': 50,"
+ " 'model': 'rerank-2.5-lite'"
+ " }"
+ "}",
rerank(
rerankQuery("machine learning tutorials"),
asList("content", "title"),
50,
"rerank-2.5-lite"
));
}

@Test
public void testRerankWithBsonQuery() {
assertPipeline(
"{"
+ " '$rerank': {"
+ " 'query': {'text': 'machine learning tutorials', 'imageURL': 'https://example.com/img.png'},"
+ " 'path': 'content',"
+ " 'numDocsToRerank': 25,"
+ " 'model': 'rerank-2.5'"
+ " }"
+ "}",
rerank(
rerankQuery(new Document("text", "machine learning tutorials")
.append("imageURL", "https://example.com/img.png")),
"content",
25,
"rerank-2.5"
));
}

@Test
public void testRerankWithMultiplePathsAndBsonQuery() {
assertPipeline(
"{"
+ " '$rerank': {"
+ " 'query': {'text': 'machine learning tutorials', 'imageURL': 'https://example.com/img.png'},"
+ " 'path': ['content', 'title'],"
+ " 'numDocsToRerank': 100,"
+ " 'model': 'rerank-2'"
+ " }"
+ "}",
rerank(
rerankQuery(new Document("text", "machine learning tutorials")
.append("imageURL", "https://example.com/img.png")),
asList("content", "title"),
100,
"rerank-2"
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.mongodb.client.model.search.FieldSearchPath

import scala.collection.JavaConverters._
import com.mongodb.client.model.{ Aggregates => JAggregates }
import com.mongodb.client.model.RerankQuery
import org.mongodb.scala.MongoNamespace
import org.mongodb.scala.bson.conversions.Bson
import org.mongodb.scala.model.densify.{ DensifyOptions, DensifyRange }
Expand Down Expand Up @@ -746,6 +747,50 @@ object Aggregates {
): Bson =
JAggregates.vectorSearch(path, queryVector.asJava, index, limit, options)

/**
* Creates a `\$rerank` pipeline stage supported by MongoDB Atlas.
* You may use the `\$meta: "score"` expression to extract the relevance score
* assigned to each reranked document.
*
* @param query The query to rerank against.
* @param path The document field to send to the reranker.
* @param numDocsToRerank The maximum number of documents to rerank (1-1000).
* @param model The reranking model name.
* @return The `\$rerank` pipeline stage.
* @note Requires MongoDB on Atlas 8.3 or greater
* @since 5.8
*/
@Beta(Array(Reason.SERVER))
def rerank(
query: RerankQuery,
path: String,
numDocsToRerank: Int,
model: String
): Bson =
JAggregates.rerank(query, path, numDocsToRerank, model)

/**
* Creates a `\$rerank` pipeline stage supported by MongoDB Atlas.
* You may use the `\$meta: "score"` expression to extract the relevance score
* assigned to each reranked document.
*
* @param query The query to rerank against.
* @param paths The document field(s) to send to the reranker.
* @param numDocsToRerank The maximum number of documents to rerank (1-1000).
* @param model The reranking model name.
* @return The `\$rerank` pipeline stage.
* @note Requires MongoDB on Atlas 8.3 or greater
* @since 5.8
*/
@Beta(Array(Reason.SERVER))
def rerank(
query: RerankQuery,
paths: Seq[String],
numDocsToRerank: Int,
model: String
): Bson =
JAggregates.rerank(query, paths.toList.asJava, numDocsToRerank, model)

/**
* Creates an `\$unset` pipeline stage that removes/excludes fields from documents
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,8 @@ package object model {

type GeoNearOptions = com.mongodb.client.model.GeoNearOptions

type RerankQuery = com.mongodb.client.model.RerankQuery

/**
* @see `QuantileMethod.approximate()`
*/
Expand Down
Loading