Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 68 additions & 45 deletions extensions/src/main/java/dev/cel/extensions/CelOptionalLibrary.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static dev.cel.extensions.CelOptionalLibrary.Function.FIRST;
import static dev.cel.extensions.CelOptionalLibrary.Function.HAS_VALUE;
import static dev.cel.extensions.CelOptionalLibrary.Function.LAST;
import static dev.cel.extensions.CelOptionalLibrary.Function.OPTIONAL_NONE;
import static dev.cel.extensions.CelOptionalLibrary.Function.OPTIONAL_OF;
import static dev.cel.extensions.CelOptionalLibrary.Function.OPTIONAL_OF_NON_ZERO_VALUE;
import static dev.cel.extensions.CelOptionalLibrary.Function.OPTIONAL_UNWRAP;
import static dev.cel.extensions.CelOptionalLibrary.Function.VALUE;
import static dev.cel.runtime.CelFunctionBinding.from;
import static dev.cel.runtime.CelFunctionBinding.fromOverloads;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -46,7 +56,6 @@
import dev.cel.parser.CelMacro;
import dev.cel.parser.CelMacroExprFactory;
import dev.cel.parser.CelParserBuilder;
import dev.cel.runtime.CelFunctionBinding;
import dev.cel.runtime.CelInternalRuntimeLibrary;
import dev.cel.runtime.CelRuntimeBuilder;
import dev.cel.runtime.RuntimeEquality;
Expand Down Expand Up @@ -97,26 +106,26 @@ public String getFunction() {
0,
ImmutableSet.of(
CelFunctionDecl.newFunctionDeclaration(
Function.OPTIONAL_OF.getFunction(),
OPTIONAL_OF.getFunction(),
CelOverloadDecl.newGlobalOverload(
"optional_of", optionalTypeV, paramTypeV)),
CelFunctionDecl.newFunctionDeclaration(
Function.OPTIONAL_OF_NON_ZERO_VALUE.getFunction(),
OPTIONAL_OF_NON_ZERO_VALUE.getFunction(),
CelOverloadDecl.newGlobalOverload(
"optional_ofNonZeroValue", optionalTypeV, paramTypeV)),
CelFunctionDecl.newFunctionDeclaration(
Function.OPTIONAL_NONE.getFunction(),
OPTIONAL_NONE.getFunction(),
CelOverloadDecl.newGlobalOverload("optional_none", optionalTypeV)),
CelFunctionDecl.newFunctionDeclaration(
Function.VALUE.getFunction(),
VALUE.getFunction(),
CelOverloadDecl.newMemberOverload(
"optional_value", paramTypeV, optionalTypeV)),
CelFunctionDecl.newFunctionDeclaration(
Function.HAS_VALUE.getFunction(),
HAS_VALUE.getFunction(),
CelOverloadDecl.newMemberOverload(
"optional_hasValue", SimpleType.BOOL, optionalTypeV)),
CelFunctionDecl.newFunctionDeclaration(
Function.OPTIONAL_UNWRAP.getFunction(),
OPTIONAL_UNWRAP.getFunction(),
CelOverloadDecl.newGlobalOverload(
"optional_unwrap_list", listTypeV, ListType.create(optionalTypeV))),
// Note: Implementation of "or" and "orValue" are special-cased inside the
Expand Down Expand Up @@ -193,15 +202,15 @@ public String getFunction() {
.addAll(version1.functions)
.add(
CelFunctionDecl.newFunctionDeclaration(
Function.FIRST.functionName,
FIRST.functionName,
CelOverloadDecl.newMemberOverload(
"optional_list_first",
"Return the first value in a list if present, otherwise"
+ " optional.none()",
optionalTypeV,
listTypeV)),
CelFunctionDecl.newFunctionDeclaration(
Function.LAST.functionName,
LAST.functionName,
CelOverloadDecl.newMemberOverload(
"optional_list_last",
"Return the last value in a list if present, otherwise"
Expand Down Expand Up @@ -295,48 +304,65 @@ public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) {
public void setRuntimeOptions(
CelRuntimeBuilder runtimeBuilder, RuntimeEquality runtimeEquality, CelOptions celOptions) {
runtimeBuilder.addFunctionBindings(
CelFunctionBinding.from("optional_of", Object.class, Optional::of),
CelFunctionBinding.from(
"optional_ofNonZeroValue",
Object.class,
val -> {
if (isZeroValue(val)) {
return Optional.empty();
}
return Optional.of(val);
}),
CelFunctionBinding.from(
"optional_unwrap_list", Collection.class, CelOptionalLibrary::elideOptionalCollection),
CelFunctionBinding.from("optional_none", ImmutableList.of(), val -> Optional.empty()),
CelFunctionBinding.from("optional_value", Object.class, val -> ((Optional<?>) val).get()),
CelFunctionBinding.from(
"optional_hasValue", Object.class, val -> ((Optional<?>) val).isPresent()),
CelFunctionBinding.from(
fromOverloads(OPTIONAL_OF.getFunction(), from("optional_of", Object.class, Optional::of)));
runtimeBuilder.addFunctionBindings(
fromOverloads(
OPTIONAL_OF_NON_ZERO_VALUE.getFunction(),
from(
"optional_ofNonZeroValue",
Object.class,
val -> {
if (isZeroValue(val)) {
return Optional.empty();
}
return Optional.of(val);
})));
runtimeBuilder.addFunctionBindings(
fromOverloads(
OPTIONAL_UNWRAP.getFunction(),
from(
"optional_unwrap_list",
Collection.class,
CelOptionalLibrary::elideOptionalCollection)));
runtimeBuilder.addFunctionBindings(
fromOverloads(
OPTIONAL_NONE.getFunction(),
from("optional_none", ImmutableList.of(), val -> Optional.empty())));
runtimeBuilder.addFunctionBindings(
fromOverloads(
VALUE.getFunction(),
from("optional_value", Object.class, val -> ((Optional<?>) val).get())));
runtimeBuilder.addFunctionBindings(
fromOverloads(
HAS_VALUE.getFunction(),
from("optional_hasValue", Object.class, val -> ((Optional<?>) val).isPresent())));

runtimeBuilder.addFunctionBindings(
from(
"select_optional_field", // This only handles map selection. Proto selection is
// special cased inside the interpreter.
Map.class,
String.class,
runtimeEquality::findInMap),
CelFunctionBinding.from(
"map_optindex_optional_value", Map.class, Object.class, runtimeEquality::findInMap),
CelFunctionBinding.from(
from("map_optindex_optional_value", Map.class, Object.class, runtimeEquality::findInMap),
from(
"optional_map_optindex_optional_value",
Optional.class,
Object.class,
(Optional optionalMap, Object key) ->
indexOptionalMap(optionalMap, key, runtimeEquality)),
CelFunctionBinding.from(
from(
"optional_map_index_value",
Optional.class,
Object.class,
(Optional optionalMap, Object key) ->
indexOptionalMap(optionalMap, key, runtimeEquality)),
CelFunctionBinding.from(
from(
"optional_list_index_int",
Optional.class,
Long.class,
CelOptionalLibrary::indexOptionalList),
CelFunctionBinding.from(
from(
"list_optindex_optional_int",
List.class,
Long.class,
Expand All @@ -347,18 +373,16 @@ public void setRuntimeOptions(
}
return Optional.of(list.get(castIndex));
}),
CelFunctionBinding.from(
from(
"optional_list_optindex_optional_int",
Optional.class,
Long.class,
CelOptionalLibrary::indexOptionalList));

if (version >= 2) {
runtimeBuilder.addFunctionBindings(
CelFunctionBinding.from(
"optional_list_first", Collection.class, CelOptionalLibrary::listOptionalFirst),
CelFunctionBinding.from(
"optional_list_last", Collection.class, CelOptionalLibrary::listOptionalLast));
from("optional_list_first", Collection.class, CelOptionalLibrary::listOptionalFirst),
from("optional_list_last", Collection.class, CelOptionalLibrary::listOptionalLast));
}
}

Expand Down Expand Up @@ -425,19 +449,18 @@ private static Optional<CelExpr> expandOptMap(
return Optional.of(
exprFactory.newGlobalCall(
Operator.CONDITIONAL.getFunction(),
exprFactory.newReceiverCall(Function.HAS_VALUE.getFunction(), target),
exprFactory.newReceiverCall(HAS_VALUE.getFunction(), target),
exprFactory.newGlobalCall(
Function.OPTIONAL_OF.getFunction(),
OPTIONAL_OF.getFunction(),
exprFactory.fold(
UNUSED_ITER_VAR,
exprFactory.newList(),
varName,
exprFactory.newReceiverCall(
Function.VALUE.getFunction(), exprFactory.copy(target)),
exprFactory.newReceiverCall(VALUE.getFunction(), exprFactory.copy(target)),
exprFactory.newBoolLiteral(true),
exprFactory.newIdentifier(varName),
mapExpr)),
exprFactory.newGlobalCall(Function.OPTIONAL_NONE.getFunction())));
exprFactory.newGlobalCall(OPTIONAL_NONE.getFunction())));
}

private static Optional<CelExpr> expandOptFlatMap(
Expand All @@ -460,16 +483,16 @@ private static Optional<CelExpr> expandOptFlatMap(
return Optional.of(
exprFactory.newGlobalCall(
Operator.CONDITIONAL.getFunction(),
exprFactory.newReceiverCall(Function.HAS_VALUE.getFunction(), target),
exprFactory.newReceiverCall(HAS_VALUE.getFunction(), target),
exprFactory.fold(
UNUSED_ITER_VAR,
exprFactory.newList(),
varName,
exprFactory.newReceiverCall(Function.VALUE.getFunction(), exprFactory.copy(target)),
exprFactory.newReceiverCall(VALUE.getFunction(), exprFactory.copy(target)),
exprFactory.newBoolLiteral(true),
exprFactory.newIdentifier(varName),
mapExpr),
exprFactory.newGlobalCall(Function.OPTIONAL_NONE.getFunction())));
exprFactory.newGlobalCall(OPTIONAL_NONE.getFunction())));
}

private static Object indexOptionalMap(
Expand Down
6 changes: 6 additions & 0 deletions runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ public Object advanceEvaluation(UnknownContext context) throws CelEvaluationExce

static Builder newBuilder() {
return new AutoValue_CelRuntimeImpl.Builder()
.setFunctionBindings(ImmutableMap.of())
.setStandardFunctions(CelStandardFunctions.newBuilder().build())
.setContainer(CelContainer.newBuilder().build())
.setExtensionRegistry(ExtensionRegistry.getEmptyRegistry());
Expand Down Expand Up @@ -222,6 +223,8 @@ abstract static class Builder implements CelRuntimeBuilder {

abstract ExtensionRegistry extensionRegistry();

abstract ImmutableMap<String, CelFunctionBinding> functionBindings();

abstract ImmutableSet.Builder<Descriptors.FileDescriptor> fileDescriptorsBuilder();

abstract ImmutableSet.Builder<CelRuntimeLibrary> runtimeLibrariesBuilder();
Expand Down Expand Up @@ -442,6 +445,9 @@ public CelRuntime build() {
DescriptorTypeResolver descriptorTypeResolver =
DescriptorTypeResolver.create(combinedTypeProvider);
TypeFunction typeFunction = TypeFunction.create(descriptorTypeResolver);

mutableFunctionBindings.putAll(functionBindings());

for (CelFunctionBinding binding :
typeFunction.newFunctionBindings(options(), runtimeEquality)) {
mutableFunctionBindings.put(binding.getOverloadId(), binding);
Expand Down
1 change: 1 addition & 0 deletions runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ java_library(
"//runtime:interpretable",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
"@maven//:org_jspecify_jspecify",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import dev.cel.common.types.CelType;
import dev.cel.common.types.CelTypeProvider;
import dev.cel.common.types.EnumType;
import dev.cel.common.types.SimpleType;
import dev.cel.common.types.TypeType;
import dev.cel.common.values.CelValue;
import dev.cel.common.values.CelValueConverter;
import dev.cel.runtime.GlobalResolver;
import java.util.NoSuchElementException;
import org.jspecify.annotations.Nullable;

@Immutable
final class NamespacedAttribute implements Attribute {
Expand All @@ -34,6 +36,14 @@ final class NamespacedAttribute implements Attribute {
private final CelValueConverter celValueConverter;
private final CelTypeProvider typeProvider;

ImmutableList<Qualifier> qualifiers() {
return qualifiers;
}

ImmutableSet<String> candidateVariableNames() {
return namespacedNames;
}

@Override
public Object resolve(GlobalResolver ctx, ExecutionFrame frame) {
GlobalResolver inputVars = ctx;
Expand All @@ -59,41 +69,62 @@ public Object resolve(GlobalResolver ctx, ExecutionFrame frame) {
}
}

CelType type = typeProvider.findType(name).orElse(null);
if (type != null) {
if (qualifiers.isEmpty()) {
// Resolution of a fully qualified type name: foo.bar.baz
return TypeType.create(type);
} else {
// This is potentially a fully qualified reference to an enum value
if (type instanceof EnumType && qualifiers.size() == 1) {
EnumType enumType = (EnumType) type;
String strQualifier = (String) qualifiers.get(0).value();
return enumType
.findNumberByName(strQualifier)
.orElseThrow(
() ->
new NoSuchElementException(
String.format(
"Field %s was not found on enum %s",
enumType.name(), strQualifier)));
}
}

throw new IllegalStateException(
"Unexpected type resolution when there were remaining qualifiers: " + type.name());
// Attempt to resolve the qualify type name if the name is not a variable identifier
value = findIdent(name);
if (value != null) {
return value;
}
}

return MissingAttribute.newMissingAttribute(namespacedNames);
}

ImmutableList<Qualifier> qualifiers() {
return qualifiers;
private @Nullable Object findIdent(String name) {
CelType type = typeProvider.findType(name).orElse(null);
// If the name resolves directly, this is a fully qualified type name
// (ex: 'int' or 'google.protobuf.Timestamp')
if (type != null) {
if (qualifiers.isEmpty()) {
// Resolution of a fully qualified type name: foo.bar.baz
if (type instanceof TypeType) {
// Coalesce all type(foo) "type" into a sentinel runtime type to allow for
// erasure based type comparisons
return TypeType.create(SimpleType.DYN);
}

return TypeType.create(type);
}

throw new IllegalStateException(
"Unexpected type resolution when there were remaining qualifiers: " + type.name());
}

// The name itself could be a fully qualified reference to an enum value
// (e.g: my.enum_type.BAR)
int lastDotIndex = name.lastIndexOf('.');
if (lastDotIndex > 0) {
String enumTypeName = name.substring(0, lastDotIndex);
String enumValueQualifier = name.substring(lastDotIndex + 1);

return typeProvider
.findType(enumTypeName)
.filter(EnumType.class::isInstance)
.map(EnumType.class::cast)
.map(enumType -> getEnumValue(enumType, enumValueQualifier))
.orElse(null);
}

return null;
}

ImmutableSet<String> candidateVariableNames() {
return namespacedNames;
private static Long getEnumValue(EnumType enumType, String field) {
return enumType
.findNumberByName(field)
.map(Integer::longValue)
.orElseThrow(
() ->
new NoSuchElementException(
String.format("Field %s was not found on enum %s", enumType.name(), field)));
}

private GlobalResolver unwrapToNonLocal(GlobalResolver resolver) {
Expand Down
4 changes: 4 additions & 0 deletions runtime/src/test/java/dev/cel/runtime/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@ java_library(
"PlannerInterpreterTest.java",
],
deps = [
"//common:cel_ast",
"//common:compiler_common",
"//common:container",
"//common:options",
"//common/types:type_providers",
"//extensions",
"//runtime",
"//runtime:runtime_planner_impl",
Expand Down
Loading
Loading