diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java index f3894b0e1..44377db09 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java @@ -182,6 +182,7 @@ public Object advanceEvaluation(UnknownContext context) throws CelEvaluationExce static Builder newBuilder() { return new AutoValue_CelRuntimeImpl.Builder() + .setFunctionBindings(ImmutableMap.of()) .setStandardFunctions(CelStandardFunctions.newBuilder().build()) .setContainer(CelContainer.newBuilder().build()) .setExtensionRegistry(ExtensionRegistry.getEmptyRegistry()); @@ -222,6 +223,8 @@ abstract static class Builder implements CelRuntimeBuilder { abstract ExtensionRegistry extensionRegistry(); + abstract ImmutableMap functionBindings(); + abstract ImmutableSet.Builder fileDescriptorsBuilder(); abstract ImmutableSet.Builder runtimeLibrariesBuilder(); @@ -442,6 +445,9 @@ public CelRuntime build() { DescriptorTypeResolver descriptorTypeResolver = DescriptorTypeResolver.create(combinedTypeProvider); TypeFunction typeFunction = TypeFunction.create(descriptorTypeResolver); + + mutableFunctionBindings.putAll(functionBindings()); + for (CelFunctionBinding binding : typeFunction.newFunctionBindings(options(), runtimeEquality)) { mutableFunctionBindings.put(binding.getOverloadId(), binding); diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 097282f6b..569d7372d 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -131,7 +131,11 @@ java_library( "PlannerInterpreterTest.java", ], deps = [ + "//common:cel_ast", + "//common:compiler_common", + "//common:container", "//common:options", + "//common/types:type_providers", "//extensions", "//runtime", "//runtime:runtime_planner_impl", diff --git a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java index 498d9f797..337061afa 100644 --- a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java @@ -14,8 +14,13 @@ package dev.cel.runtime; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelContainer; import dev.cel.common.CelOptions; +import dev.cel.common.CelValidationException; +import dev.cel.common.types.CelTypeProvider; import dev.cel.extensions.CelExtensions; import dev.cel.testing.BaseInterpreterTest; import org.junit.runner.RunWith; @@ -24,6 +29,8 @@ @RunWith(TestParameterInjector.class) public class PlannerInterpreterTest extends BaseInterpreterTest { + @TestParameter boolean isParseOnly; + @Override protected CelRuntimeBuilder newBaseRuntimeBuilder(CelOptions celOptions) { return CelRuntimeImpl.newBuilder() @@ -34,6 +41,36 @@ protected CelRuntimeBuilder newBaseRuntimeBuilder(CelOptions celOptions) { .addFileTypes(TEST_FILE_DESCRIPTORS); } + @Override + protected void setContainer(CelContainer container) { + super.setContainer(container); + this.celRuntime = this.celRuntime.toRuntimeBuilder().setContainer(container).build(); + } + + @Override + protected CelAbstractSyntaxTree prepareTest(CelTypeProvider typeProvider) { + super.prepareCompiler(typeProvider); + + CelAbstractSyntaxTree ast; + try { + ast = celCompiler.parse(source, testSourceDescription()).getAst(); + } catch (CelValidationException e) { + printTestValidationError(e); + return null; + } + + if (isParseOnly) { + return ast; + } + + try { + return celCompiler.check(ast).getAst(); + } catch (CelValidationException e) { + printTestValidationError(e); + return null; + } + } + @Override public void unknownField() { // TODO: Unknown support not implemented yet @@ -45,4 +82,25 @@ public void unknownResultSet() { // TODO: Unknown support not implemented yet skipBaselineVerification(); } + + @Override + public void optional() { + if (isParseOnly) { + // TODO: Fix for parsed-only mode. + skipBaselineVerification(); + } else { + super.optional(); + } + } + + @Override + public void optional_errors() { + if (isParseOnly) { + // Parsed-only evaluation contains function name in the + // error message instead of the function overload. + skipBaselineVerification(); + } else { + super.optional_errors(); + } + } } diff --git a/testing/src/main/java/dev/cel/testing/BUILD.bazel b/testing/src/main/java/dev/cel/testing/BUILD.bazel index 5ecfd37f2..f2480a034 100644 --- a/testing/src/main/java/dev/cel/testing/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/BUILD.bazel @@ -93,6 +93,7 @@ java_library( "//runtime:late_function_binding", "//runtime:unknown_attributes", "@cel_spec//proto/cel/expr:checked_java_proto", + "@cel_spec//proto/cel/expr:syntax_java_proto", "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", "@maven//:com_google_errorprone_error_prone_annotations", diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index f1ba69c6f..bc67e8218 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -20,6 +20,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import dev.cel.expr.CheckedExpr; +import dev.cel.expr.ParsedExpr; import dev.cel.expr.Type; import com.google.common.base.Ascii; import com.google.common.collect.ImmutableList; @@ -87,6 +88,7 @@ import java.time.Duration; import java.time.Instant; import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -115,8 +117,8 @@ public abstract class BaseInterpreterTest extends CelBaselineTestCase { .enableOptionalSyntax(true) .comprehensionMaxIterations(1_000) .build(); - private CelRuntime celRuntime; - protected CelOptions celOptions; + protected CelRuntime celRuntime; + private CelOptions celOptions; protected BaseInterpreterTest() { this.celOptions = BASE_CEL_OPTIONS; @@ -155,6 +157,10 @@ protected void prepareCompiler(CelTypeProvider typeProvider) { .build(); } + protected void setContainer(CelContainer container) { + this.container = container; + } + private CelAbstractSyntaxTree compileTestCase() { CelAbstractSyntaxTree ast = prepareTest(TEST_FILE_DESCRIPTORS); if (ast == null) { @@ -352,7 +358,7 @@ public void quantifiers() { @Test public void arithmTimestamp() { - container = CelContainer.ofName(Type.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(Type.getDescriptor().getFile().getPackage())); declareVariable("ts1", SimpleType.TIMESTAMP); declareVariable("ts2", SimpleType.TIMESTAMP); declareVariable("d1", SimpleType.DURATION); @@ -379,7 +385,7 @@ public void arithmTimestamp() { @Test public void arithmDuration() { - container = CelContainer.ofName(Type.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(Type.getDescriptor().getFile().getPackage())); declareVariable("d1", SimpleType.DURATION); declareVariable("d2", SimpleType.DURATION); declareVariable("d3", SimpleType.DURATION); @@ -536,14 +542,14 @@ public void messages() throws Exception { runTest(ImmutableMap.of("single_nested_message", nestedMessage.getSingleNestedMessage())); source = "TestAllTypes{single_int64: 1, single_sfixed64: 2, single_int32: 2}.single_int32 == 2"; - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); runTest(); } @Test public void messages_error() { source = "TestAllTypes{single_int32_wrapper: 12345678900}"; - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); runTest(); source = "TestAllTypes{}.map_string_string.a"; @@ -569,12 +575,12 @@ public void optional_errors() { @Test public void containers() { - container = + setContainer( CelContainer.newBuilder() .setName("dev.cel.testing.testdata.proto3.StandaloneGlobalEnum") .addAlias("test_alias", TestAllTypes.getDescriptor().getFile().getPackage()) .addAbbreviations("cel.expr.conformance.proto2", "cel.expr.conformance.proto3") - .build(); + .build()); source = "test_alias.TestAllTypes{} == cel.expr.conformance.proto3.TestAllTypes{}"; runTest(); @@ -621,7 +627,7 @@ public void duration() throws Exception { java.time.Duration d1010 = java.time.Duration.ofSeconds(10, 10); java.time.Duration d1009 = java.time.Duration.ofSeconds(10, 9); java.time.Duration d0910 = java.time.Duration.ofSeconds(9, 10); - container = CelContainer.ofName(Type.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(Type.getDescriptor().getFile().getPackage())); source = "d1 < d2"; runTest(extend(ImmutableMap.of("d1", d1010), ImmutableMap.of("d2", d1009))); @@ -659,7 +665,7 @@ public void timestamp() throws Exception { Instant ts1010 = Instant.ofEpochSecond(10, 10); Instant ts1009 = Instant.ofEpochSecond(10, 9); Instant ts0910 = Instant.ofEpochSecond(9, 10); - container = CelContainer.ofName(Type.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(Type.getDescriptor().getFile().getPackage())); source = "t1 < t2"; runTest(extend(ImmutableMap.of("t1", ts1010), ImmutableMap.of("t2", ts1009))); @@ -697,7 +703,7 @@ public void packUnpackAny() { skipBaselineVerification(); return; } - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); declareVariable("any", SimpleType.ANY); declareVariable("d", SimpleType.DURATION); declareVariable( @@ -745,7 +751,7 @@ public void packUnpackAny() { public void nestedEnums() { TestAllTypes nestedEnum = TestAllTypes.newBuilder().setSingleNestedEnum(NestedEnum.BAR).build(); declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); source = "x.single_nested_enum == TestAllTypes.NestedEnum.BAR"; runTest(ImmutableMap.of("x", nestedEnum)); @@ -769,7 +775,7 @@ public void globalEnums() { public void lists() throws Exception { declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); declareVariable("y", SimpleType.INT); - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); source = "([1, 2, 3] + x.repeated_int32)[3] == 4"; runTest(ImmutableMap.of("x", TestAllTypes.newBuilder().addRepeatedInt32(4).build())); @@ -809,7 +815,7 @@ public void lists_error() { @Test public void maps() throws Exception { declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); source = "{1: 2, 3: 4}[3] == 4"; runTest(); @@ -869,19 +875,19 @@ public void comprehension() throws Exception { runTest(); declareVariable("com.x", SimpleType.INT); - container = CelContainer.ofName("com"); + setContainer(CelContainer.ofName("com")); source = "[0].exists(x, x == 0)"; runTest(ImmutableMap.of("com.x", 1)); clearAllDeclarations(); declareVariable("cel.example.y", SimpleType.INT); - container = CelContainer.ofName("cel.example"); + setContainer(CelContainer.ofName("cel.example")); source = "[{'z': 0}].exists(y, y.z == 0)"; runTest(ImmutableMap.of("cel.example.y", ImmutableMap.of("z", 1))); clearAllDeclarations(); declareVariable("y.z", SimpleType.INT); - container = CelContainer.ofName("y"); + setContainer(CelContainer.ofName("y")); source = "[{'z': 0}].exists(y, y.z == 0 && .y.z == 1)"; runTest(ImmutableMap.of("y.z", 1)); @@ -942,8 +948,14 @@ public void namespacedFunctions() { ImmutableList.of(SimpleType.INT, SimpleType.INT), SimpleType.INT)); addFunctionBinding( - CelFunctionBinding.from("ns_func_overload", String.class, s -> (long) s.length()), - CelFunctionBinding.from("ns_member_overload", Long.class, Long.class, Long::sum)); + CelFunctionBinding.fromOverloads( + "ns.func", + CelFunctionBinding.from("ns_func_overload", String.class, s -> (long) s.length()))); + addFunctionBinding( + CelFunctionBinding.fromOverloads( + "member", + CelFunctionBinding.from("ns_member_overload", Long.class, Long.class, Long::sum))); + source = "ns.func('hello')"; runTest(); @@ -965,7 +977,7 @@ public void namespacedFunctions() { source = "[1, 2].map(x, x * ns.func('test'))"; runTest(); - container = CelContainer.ofName("ns"); + setContainer(CelContainer.ofName("ns")); // Call with the container set as the function's namespace source = "ns.func('hello')"; runTest(); @@ -979,12 +991,12 @@ public void namespacedFunctions() { @Test public void namespacedVariables() { - container = CelContainer.ofName("ns"); + setContainer(CelContainer.ofName("ns")); declareVariable("ns.x", SimpleType.INT); source = "x"; runTest(ImmutableMap.of("ns.x", 2)); - container = CelContainer.ofName("dev.cel.testing.testdata.proto3"); + setContainer(CelContainer.ofName("dev.cel.testing.testdata.proto3")); CelType messageType = StructTypeReference.create("cel.expr.conformance.proto3.TestAllTypes"); declareVariable("dev.cel.testing.testdata.proto3.msgVar", messageType); source = "msgVar.single_int32"; @@ -1002,7 +1014,7 @@ public void durationFunctions() { Duration d1 = Duration.ofSeconds(totalSeconds, nanos); Duration d2 = Duration.ofSeconds(-totalSeconds, -nanos); - container = CelContainer.ofName(Type.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(Type.getDescriptor().getFile().getPackage())); source = "d1.getHours()"; runTest(ImmutableMap.of("d1", d1)); @@ -1034,7 +1046,7 @@ public void durationFunctions() { @Test public void timestampFunctions() { declareVariable("ts1", SimpleType.TIMESTAMP); - container = CelContainer.ofName(Type.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(Type.getDescriptor().getFile().getPackage())); Instant ts1 = Instant.ofEpochSecond(1, 11000000); Instant ts2 = Instant.ofEpochSecond(-1, 0); @@ -1163,7 +1175,7 @@ public void timestampFunctions() { @Test public void unknownField() { - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); // Unknown field is accessed. @@ -1195,7 +1207,7 @@ public void unknownField() { @Test public void unknownResultSet() { - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); TestAllTypes message = TestAllTypes.newBuilder() @@ -1397,7 +1409,7 @@ public void unknownResultSet() { @Test public void timeConversions() { - container = CelContainer.ofName(Type.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(Type.getDescriptor().getFile().getPackage())); declareVariable("t1", SimpleType.TIMESTAMP); source = "timestamp(\"1972-01-01T10:00:20.021-05:00\")"; @@ -1430,7 +1442,7 @@ public void timeConversions_error() { @Test public void sizeTests() { - container = CelContainer.ofName(Type.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(Type.getDescriptor().getFile().getPackage())); declareVariable("str", SimpleType.STRING); declareVariable("b", SimpleType.BYTES); @@ -1801,7 +1813,7 @@ public void dyn_error() { @Test public void jsonValueTypes() { - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); // JSON bool selection. @@ -1918,7 +1930,7 @@ public void jsonConversions() { @Test public void typeComparisons() { - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); // Test numeric types. source = @@ -2078,7 +2090,7 @@ public void wrappers() throws Exception { declareVariable("string_list", ListType.create(SimpleType.STRING)); declareVariable("bytes_list", ListType.create(SimpleType.BYTES)); - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFullName()); + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFullName())); source = "TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && " + "TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && " @@ -2266,7 +2278,7 @@ public void dynamicMessage_adapted() throws Exception { @Test public void dynamicMessage_dynamicDescriptor() throws Exception { - container = CelContainer.ofName("dev.cel.testing.testdata.serialized.proto3"); + setContainer(CelContainer.ofName("dev.cel.testing.testdata.serialized.proto3")); source = "TestAllTypes {}"; assertThat(runTest()).isInstanceOf(DynamicMessage.class); @@ -2363,8 +2375,10 @@ public void dynamicMessage_dynamicDescriptor() throws Exception { StructTypeReference.create(TEST_ALL_TYPE_DYNAMIC_DESCRIPTOR.getFullName())), SimpleType.BOOL)); addFunctionBinding( - CelFunctionBinding.from("f_msg_generated", TestAllTypes.class, unused -> true), - CelFunctionBinding.from("f_msg_dynamic", DynamicMessage.class, unused -> true)); + CelFunctionBinding.fromOverloads( + "f_msg", + CelFunctionBinding.from("f_msg_generated", TestAllTypes.class, unused -> true), + CelFunctionBinding.from("f_msg_dynamic", DynamicMessage.class, unused -> true))); input = ImmutableMap.of( "dynamic_msg", dynamicMessageBuilder.build(), @@ -2410,14 +2424,38 @@ public void lateBoundFunctions() throws Exception { assertThat(recordedValues.getRecordedValues()).containsExactly("foo", "bar"); } + @Test + public void jsonFieldNames() throws Exception { + this.celOptions = celOptions.toBuilder().enableJsonFieldNames(true).build(); + this.celRuntime = newBaseRuntimeBuilder(celOptions).build(); + + TestAllTypes message = TestAllTypes.newBuilder().setSingleInt32(42).build(); + declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); + + source = "x.singleInt32 == 42"; + assertThat(runTest(ImmutableMap.of("x", message))).isEqualTo(true); + + source = "TestAllTypes{singleInt32: 42}.singleInt32 == 42"; + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); + assertThat(runTest()).isEqualTo(true); + + skipBaselineVerification(); + } + /** * Checks that the CheckedExpr produced by CelCompiler is equal to the one reproduced by the * native CelAbstractSyntaxTree */ private static void assertAstRoundTrip(CelAbstractSyntaxTree ast) { - CheckedExpr checkedExpr = CelProtoAbstractSyntaxTree.fromCelAst(ast).toCheckedExpr(); - CelProtoAbstractSyntaxTree protoAst = CelProtoAbstractSyntaxTree.fromCelAst(ast); - assertThat(checkedExpr).isEqualTo(protoAst.toCheckedExpr()); + if (ast.isChecked()) { + CheckedExpr checkedExpr = CelProtoAbstractSyntaxTree.fromCelAst(ast).toCheckedExpr(); + CelProtoAbstractSyntaxTree protoAst = CelProtoAbstractSyntaxTree.fromCelAst(ast); + assertThat(checkedExpr).isEqualTo(protoAst.toCheckedExpr()); + } else { + ParsedExpr parsedExpr = CelProtoAbstractSyntaxTree.fromCelAst(ast).toParsedExpr(); + CelProtoAbstractSyntaxTree protoAst = CelProtoAbstractSyntaxTree.fromCelAst(ast); + assertThat(parsedExpr).isEqualTo(protoAst.toParsedExpr()); + } } private static String readResourceContent(String path) throws IOException { @@ -2505,6 +2543,10 @@ private static CelVariableResolver extend(CelVariableResolver primary, Map functionBindings) { celRuntime = celRuntime.toRuntimeBuilder().addFunctionBindings(functionBindings).build(); } @@ -2528,22 +2570,4 @@ private static Descriptor getDeserializedTestAllTypeDescriptor() { throw new RuntimeException("Error loading TestAllTypes descriptor", e); } } - - @Test - public void jsonFieldNames() throws Exception { - this.celOptions = celOptions.toBuilder().enableJsonFieldNames(true).build(); - this.celRuntime = newBaseRuntimeBuilder(celOptions).build(); - - TestAllTypes message = TestAllTypes.newBuilder().setSingleInt32(42).build(); - declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); - - source = "x.singleInt32 == 42"; - assertThat(runTest(ImmutableMap.of("x", message))).isEqualTo(true); - - source = "TestAllTypes{singleInt32: 42}.singleInt32 == 42"; - container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); - assertThat(runTest()).isEqualTo(true); - - skipBaselineVerification(); - } }