Skip to content
Draft
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,22 @@
import io.stargate.sgv2.jsonapi.config.feature.ApiFeature;
import io.stargate.sgv2.jsonapi.config.feature.ApiFeatures;
import io.stargate.sgv2.jsonapi.config.feature.FeaturesConfig;
import io.stargate.sgv2.jsonapi.logging.LoggingMDCContext;
import io.stargate.sgv2.jsonapi.metrics.CommandFeatures;
import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter;
import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.*;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider;
import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProviderFactory;
import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProviderFactory;
import io.stargate.sgv2.jsonapi.service.schema.DatabaseSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.KeyspaceSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObjectType;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.tables.TableSchemaObject;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
Expand All @@ -29,13 +36,16 @@
* context for a specific request call {@link BuilderSupplier#getBuilder(SchemaObject)} to get a
* {@link BuilderSupplier.Builder} to configure the context for the request.
*
* <p>
* <p><b>NOTE:</b> When {@link BuilderSupplier.Builder#build()} is called it will call {@link
* #addToMDC()} so that the context is added to the logging MDC for the duration of the request. The
* context must be closed via {@link #close()} to remove it from the MDC, this should be done at the
* last possible time in the resource handler so all log messages have the context.
*
* @param <SchemaT> The schema object type that this context is for. There are times we need to lock
* this down to the specific type, if so use the "as" methods such as {@link
* CommandContext#asCollectionContext()}
*/
public class CommandContext<SchemaT extends SchemaObject> {
public class CommandContext<SchemaT extends SchemaObject> implements LoggingMDCContext {

// Common for all instances
private final JsonProcessingMetricsReporter jsonProcessingMetricsReporter;
Expand All @@ -47,11 +57,15 @@ public class CommandContext<SchemaT extends SchemaObject> {

// Request specific
private final SchemaT schemaObject;
private final RequestTracing requestTracing;
private final RequestContext requestContext;
private final EmbeddingProvider
embeddingProvider; // to be removed later, this is a single provider
private final String commandName; // TODO: remove the command name, but it is used in 14 places
private final RequestContext requestContext;
private RequestTracing requestTracing;

// both per request list of objects that want to update the logging MDC context,
// add to this list in the ctor. See {@link #addToMDC()} and {@link #removeFromMDC()}
private final List<LoggingMDCContext> loggingMDCContexts = new ArrayList<>();

// see accessors
private FindAndRerankCommand.HybridLimits hybridLimits;
Expand All @@ -77,19 +91,23 @@ private CommandContext(
RerankingProviderFactory rerankingProviderFactory,
MeterRegistry meterRegistry) {

this.schemaObject = schemaObject;
this.embeddingProvider = embeddingProvider;
this.commandName = commandName;
this.requestContext = requestContext;

this.jsonProcessingMetricsReporter = jsonProcessingMetricsReporter;
// Common for all instances
this.cqlSessionCache = cqlSessionCache;
this.commandConfig = commandConfig;
this.embeddingProviderFactory = embeddingProviderFactory;
this.jsonProcessingMetricsReporter = jsonProcessingMetricsReporter;
this.meterRegistry = meterRegistry;
this.rerankingProviderFactory = rerankingProviderFactory;

// Request specific
this.embeddingProvider = embeddingProvider; // to be removed later, this is a single provider
this.requestContext = requestContext;
this.schemaObject = schemaObject;
this.commandName = commandName; // TODO: remove the command name, but it is used in 14 places
this.apiFeatures = apiFeatures;
this.meterRegistry = meterRegistry;

this.loggingMDCContexts.add(this.requestContext);
this.loggingMDCContexts.add(this.schemaObject.identifier());

var anyTracing =
apiFeatures().isFeatureEnabled(ApiFeature.REQUEST_TRACING)
Expand Down Expand Up @@ -191,41 +209,59 @@ public MeterRegistry meterRegistry() {
}

public boolean isCollectionContext() {
return schemaObject().type() == CollectionSchemaObject.TYPE;
return schemaObject().type() == SchemaObjectType.COLLECTION;
}

@SuppressWarnings("unchecked")
public CommandContext<CollectionSchemaObject> asCollectionContext() {
checkSchemaObjectType(CollectionSchemaObject.TYPE);
checkSchemaObjectType(SchemaObjectType.COLLECTION);
return (CommandContext<CollectionSchemaObject>) this;
}

@SuppressWarnings("unchecked")
public CommandContext<TableSchemaObject> asTableContext() {
checkSchemaObjectType(TableSchemaObject.TYPE);
checkSchemaObjectType(SchemaObjectType.TABLE);
return (CommandContext<TableSchemaObject>) this;
}

@SuppressWarnings("unchecked")
public CommandContext<KeyspaceSchemaObject> asKeyspaceContext() {
checkSchemaObjectType(KeyspaceSchemaObject.TYPE);
checkSchemaObjectType(SchemaObjectType.KEYSPACE);
return (CommandContext<KeyspaceSchemaObject>) this;
}

@SuppressWarnings("unchecked")
public CommandContext<DatabaseSchemaObject> asDatabaseContext() {
checkSchemaObjectType(DatabaseSchemaObject.TYPE);
checkSchemaObjectType(SchemaObjectType.DATABASE);
return (CommandContext<DatabaseSchemaObject>) this;
}

private void checkSchemaObjectType(SchemaObject.SchemaObjectType expectedType) {
private void checkSchemaObjectType(SchemaObjectType expectedType) {
Preconditions.checkArgument(
schemaObject().type() == expectedType,
"SchemaObject type actual was %s expected was %s ",
schemaObject().type(),
expectedType);
}

@Override
public void addToMDC() {
loggingMDCContexts.forEach(LoggingMDCContext::addToMDC);
}

@Override
public void removeFromMDC() {
loggingMDCContexts.forEach(LoggingMDCContext::removeFromMDC);
}

/**
* NOTE: Not using AutoCloseable because it created a lot of linting warnings, we only want to
* close this in the request resource handler.
*/
public void close() throws Exception {
removeFromMDC();
}

/**
* Configure the BuilderSupplier with resources and config that will be used for all the {@link
* CommandContext} that will be created. Then called {@link
Expand Down Expand Up @@ -341,18 +377,21 @@ public CommandContext<SchemaT> build() {
Objects.requireNonNull(commandName, "commandName must not be null");
Objects.requireNonNull(requestContext, "requestContext must not be null");

return new CommandContext<>(
schemaObject,
embeddingProvider,
commandName,
requestContext,
jsonProcessingMetricsReporter,
cqlSessionCache,
commandConfig,
apiFeatures,
embeddingProviderFactory,
rerankingProviderFactory,
meterRegistry);
var context =
new CommandContext<>(
schemaObject,
embeddingProvider,
commandName,
requestContext,
jsonProcessingMetricsReporter,
cqlSessionCache,
commandConfig,
apiFeatures,
embeddingProviderFactory,
rerankingProviderFactory,
meterRegistry);
context.addToMDC();
return context;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
* The schema object a command can be called against.
*
* <p>Example: creteTable runs against the Keyspace , so target is the Keyspace aaron 13 - nove -
* 2024 - not using the {@link
* io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject.SchemaObjectType} because this
* also needs the SYSTEM value, and the schema object design prob needs improvement
* 2024 - not using the {@link io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObjectType}
* because this also needs the SYSTEM value, and the schema object design prob needs improvement
*/
public enum CommandTarget {
COLLECTION,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import io.stargate.sgv2.jsonapi.config.OperationsConfig;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.exception.FilterException;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.tables.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.shredding.collections.DocumentId;
import io.stargate.sgv2.jsonapi.service.shredding.collections.JsonExtensionType;
import io.stargate.sgv2.jsonapi.util.JsonUtil;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition;
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause;
import io.stargate.sgv2.jsonapi.exception.SortException;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.tables.TableSchemaObject;
import io.stargate.sgv2.jsonapi.util.JsonUtil;
import java.util.Map;
import java.util.Objects;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.*;
import io.stargate.sgv2.jsonapi.api.model.command.table.definition.datatype.MapComponentDesc;
import io.stargate.sgv2.jsonapi.exception.FilterException;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.operation.filters.table.MapSetListFilterComponent;
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiTypeName;
import io.stargate.sgv2.jsonapi.service.schema.tables.TableSchemaObject;
import io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil;
import java.util.*;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause;
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression;
import io.stargate.sgv2.jsonapi.exception.SortException;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiColumnDef;
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiColumnDefContainer;
import io.stargate.sgv2.jsonapi.service.schema.tables.TableSchemaObject;
import io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil;
import io.stargate.sgv2.jsonapi.util.JsonUtil;
import java.util.ArrayList;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.stargate.sgv2.jsonapi.api.v1;

import static io.stargate.sgv2.jsonapi.config.constants.DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD;
import static io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil.cqlIdentifierFromUserInput;

import io.micrometer.core.instrument.MeterRegistry;
import io.smallrye.mutiny.Uni;
Expand Down Expand Up @@ -31,13 +32,14 @@
import io.stargate.sgv2.jsonapi.exception.mappers.ThrowableCommandResultSupplier;
import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter;
import io.stargate.sgv2.jsonapi.service.cqldriver.CqlSessionCacheSupplier;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaCache;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorColumnDefinition;
import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider;
import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProviderFactory;
import io.stargate.sgv2.jsonapi.service.processor.MeteredCommandProcessor;
import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProviderFactory;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObjectCacheSupplier;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObjectType;
import io.stargate.sgv2.jsonapi.service.schema.UnscopedSchemaObjectIdentifier;
import jakarta.inject.Inject;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotEmpty;
Expand All @@ -60,13 +62,16 @@
import org.eclipse.microprofile.openapi.annotations.security.SecurityRequirement;
import org.eclipse.microprofile.openapi.annotations.tags.Tag;
import org.jboss.resteasy.reactive.RestResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Path(CollectionResource.BASE_PATH)
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
@SecurityRequirement(name = OpenApiConstants.SecuritySchemes.TOKEN)
@Tag(ref = "Documents")
public class CollectionResource {
private static final Logger LOGGER = LoggerFactory.getLogger(CollectionResource.class);

public static final String BASE_PATH = GeneralResource.BASE_PATH + "/{keyspace}/{collection}";

Expand All @@ -75,20 +80,23 @@ public class CollectionResource {
// TODO remove apiFeatureConfig as a property after cleanup for how we get schema from cache
@Inject private FeaturesConfig apiFeatureConfig;
@Inject private RequestContext requestContext;
@Inject private SchemaCache schemaCache;

private final SchemaObjectCacheSupplier schemaObjectCacheSupplier;
private final CommandContext.BuilderSupplier contextBuilderSupplier;
private final EmbeddingProviderFactory embeddingProviderFactory;
private final MeteredCommandProcessor meteredCommandProcessor;

@Inject
public CollectionResource(
SchemaObjectCacheSupplier schemaObjectCacheSupplier,
MeteredCommandProcessor meteredCommandProcessor,
MeterRegistry meterRegistry,
JsonProcessingMetricsReporter jsonProcessingMetricsReporter,
CqlSessionCacheSupplier sessionCacheSupplier,
EmbeddingProviderFactory embeddingProviderFactory,
RerankingProviderFactory rerankingProviderFactory) {

this.schemaObjectCacheSupplier = schemaObjectCacheSupplier;
this.embeddingProviderFactory = embeddingProviderFactory;
this.meteredCommandProcessor = meteredCommandProcessor;

Expand Down Expand Up @@ -198,12 +206,15 @@ public Uni<RestResponse<CommandResult>> postCommand(
@NotNull @Valid CollectionCommand command,
@PathParam("keyspace") @NotEmpty String keyspace,
@PathParam("collection") @NotEmpty String collection) {
return schemaCache
.getSchemaObject(
requestContext,
keyspace,
collection,
CommandType.DDL.equals(command.commandName().getCommandType()))

var name =
new UnscopedSchemaObjectIdentifier.DefaultKeyspaceScopedName(
cqlIdentifierFromUserInput(keyspace), cqlIdentifierFromUserInput(collection));
var forceRefresh = CommandType.DDL.equals(command.commandName().getCommandType());

return schemaObjectCacheSupplier
.get()
.getTableBased(requestContext, name, requestContext.userAgent(), forceRefresh)
.onItemOrFailure()
.transformToUni(
(schemaObject, throwable) -> {
Expand All @@ -219,19 +230,17 @@ public Uni<RestResponse<CommandResult>> postCommand(
// otherwise use generic for now
return Uni.createFrom().item(new ThrowableCommandResultSupplier(error));
} else {
// TODO No need for the else clause here, simplify

// TODO: This needs to change, currently it is only checking if there is vectorize
// for the $vector column in a collection

VectorColumnDefinition vectorColDef = null;
if (schemaObject.type() == SchemaObject.SchemaObjectType.COLLECTION) {
if (schemaObject.type() == SchemaObjectType.COLLECTION) {
vectorColDef =
schemaObject
.vectorConfig()
.getColumnDefinition(VECTOR_EMBEDDING_TEXT_FIELD)
.orElse(null);
} else if (schemaObject.type() == SchemaObject.SchemaObjectType.TABLE) {
} else if (schemaObject.type() == SchemaObjectType.TABLE) {
vectorColDef =
schemaObject
.vectorConfig()
Expand Down Expand Up @@ -262,7 +271,20 @@ public Uni<RestResponse<CommandResult>> postCommand(
.withRequestContext(requestContext)
.build();

return meteredCommandProcessor.processCommand(commandContext, command);
return meteredCommandProcessor
.processCommand(commandContext, command)
.onTermination()
.invoke(
() -> {
try {
commandContext.close();
} catch (Exception e) {
LOGGER.error(
"Error closing the command context for requestContext={}",
requestContext,
e);
}
});
}
})
.map(commandResult -> commandResult.toRestResponse());
Expand Down
Loading
Loading