Skip to content

Commit bb10b79

Browse files
committed
Instantiate AgenticPolicyCompiler environment from YAML definitions
1 parent 182d7c2 commit bb10b79

18 files changed

+283
-214
lines changed

bundle/src/main/java/dev/cel/bundle/CelEnvironment.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import dev.cel.common.types.OptionalType;
4444
import dev.cel.common.types.SimpleType;
4545
import dev.cel.common.types.TypeParamType;
46+
import dev.cel.common.types.TypeType;
4647
import dev.cel.compiler.CelCompiler;
4748
import dev.cel.compiler.CelCompilerBuilder;
4849
import dev.cel.compiler.CelCompilerLibrary;
@@ -69,9 +70,10 @@ public abstract class CelEnvironment {
6970
"math", CanonicalCelExtension.MATH,
7071
"optional", CanonicalCelExtension.OPTIONAL,
7172
"protos", CanonicalCelExtension.PROTOS,
73+
"regex", CanonicalCelExtension.REGEX,
7274
"sets", CanonicalCelExtension.SETS,
7375
"strings", CanonicalCelExtension.STRINGS,
74-
"comprehensions", CanonicalCelExtension.COMPREHENSIONS);
76+
"two-var-comprehensions", CanonicalCelExtension.COMPREHENSIONS);
7577

7678
/** Environment source in textual format (ex: textproto, YAML). */
7779
public abstract Optional<Source> source();
@@ -82,7 +84,7 @@ public abstract class CelEnvironment {
8284
/**
8385
* Container, which captures default namespace and aliases for value resolution.
8486
*/
85-
public abstract CelContainer container();
87+
public abstract Optional<CelContainer> container();
8688

8789
/**
8890
* An optional description of the environment (example: location of the file containing the config
@@ -186,7 +188,6 @@ public static Builder newBuilder() {
186188
return new AutoValue_CelEnvironment.Builder()
187189
.setName("")
188190
.setDescription("")
189-
.setContainer(CelContainer.ofName(""))
190191
.setVariables(ImmutableSet.of())
191192
.setFunctions(ImmutableSet.of());
192193
}
@@ -199,7 +200,6 @@ public CelCompiler extend(CelCompiler celCompiler, CelOptions celOptions)
199200
CelCompilerBuilder compilerBuilder =
200201
celCompiler
201202
.toCompilerBuilder()
202-
.setContainer(container())
203203
.setTypeProvider(celTypeProvider)
204204
.addVarDeclarations(
205205
variables().stream()
@@ -210,6 +210,9 @@ public CelCompiler extend(CelCompiler celCompiler, CelOptions celOptions)
210210
.map(f -> f.toCelFunctionDecl(celTypeProvider))
211211
.collect(toImmutableList()));
212212

213+
214+
container().ifPresent(compilerBuilder::setContainer);
215+
213216
addAllCompilerExtensions(compilerBuilder, celOptions);
214217

215218
applyStandardLibrarySubset(compilerBuilder);
@@ -349,6 +352,9 @@ public abstract static class VariableDecl {
349352
/** The type of the variable. */
350353
public abstract TypeDecl type();
351354

355+
/** Description of the variable. */
356+
public abstract Optional<String> description();
357+
352358
/** Builder for {@link VariableDecl}. */
353359
@AutoValue.Builder
354360
public abstract static class Builder implements RequiredFieldsChecker {
@@ -361,6 +367,8 @@ public abstract static class Builder implements RequiredFieldsChecker {
361367

362368
public abstract VariableDecl.Builder setType(TypeDecl typeDecl);
363369

370+
public abstract VariableDecl.Builder setDescription(String name);
371+
364372
@Override
365373
public ImmutableList<RequiredField> requiredFields() {
366374
return ImmutableList.of(
@@ -600,6 +608,9 @@ public CelType toCelType(CelTypeProvider celTypeProvider) {
600608
CelType keyType = params().get(0).toCelType(celTypeProvider);
601609
CelType valueType = params().get(1).toCelType(celTypeProvider);
602610
return MapType.create(keyType, valueType);
611+
case "type":
612+
checkState(params().size() == 1, "Expected 1 parameter for type, got " + params().size());
613+
return TypeType.create(params().get(0).toCelType(celTypeProvider));
603614
default:
604615
if (isTypeParam()) {
605616
return TypeParamType.create(name());
@@ -734,10 +745,14 @@ enum CanonicalCelExtension {
734745
SETS(
735746
(options, version) -> CelExtensions.sets(options),
736747
(options, version) -> CelExtensions.sets(options)),
748+
REGEX(
749+
(options, version) -> CelExtensions.regex(),
750+
(options, version) -> CelExtensions.regex()),
737751
LISTS((options, version) -> CelExtensions.lists(), (options, version) -> CelExtensions.lists()),
738752
COMPREHENSIONS(
739753
(options, version) -> CelExtensions.comprehensions(),
740-
(options, version) -> CelExtensions.comprehensions());
754+
(options, version) -> CelExtensions.comprehensions())
755+
;
741756

742757
@SuppressWarnings("ImmutableEnumChecker")
743758
private final CompilerExtensionProvider compilerExtensionProvider;

bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,9 @@ private VariableDecl parseVariable(ParserContext<Node> ctx, Node node) {
243243
case "name":
244244
builder.setName(newString(ctx, valueNode));
245245
break;
246+
case "description":
247+
builder.setDescription(newString(ctx, valueNode));
248+
break;
246249
case "type":
247250
if (typeDeclBuilder != null) {
248251
ctx.reportError(
@@ -318,6 +321,9 @@ private FunctionDecl parseFunction(ParserContext<Node> ctx, Node node) {
318321
case "overloads":
319322
builder.setOverloads(parseOverloads(ctx, valueNode));
320323
break;
324+
case "description":
325+
// TODO: Set description
326+
break;
321327
default:
322328
ctx.reportError(keyId, String.format("Unsupported function tag: %s", keyName));
323329
break;
@@ -369,6 +375,9 @@ private static ImmutableSet<OverloadDecl> parseOverloads(ParserContext<Node> ctx
369375
case "target":
370376
overloadDeclBuilder.setTarget(parseTypeDecl(ctx, valueNode));
371377
break;
378+
case "examples":
379+
// TODO: Set examples
380+
break;
372381
default:
373382
ctx.reportError(keyId, String.format("Unsupported overload tag: %s", fieldName));
374383
break;

bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlSerializer.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,9 @@ public Node representData(Object data) {
7777
if (!environment.description().isEmpty()) {
7878
configMap.put("description", environment.description());
7979
}
80-
if (!environment.container().name().isEmpty()
81-
|| !environment.container().abbreviations().isEmpty()
82-
|| !environment.container().aliases().isEmpty()) {
83-
configMap.put("container", environment.container());
80+
81+
if (environment.container().isPresent()) {
82+
configMap.put("container", environment.container().get());
8483
}
8584
if (!environment.extensions().isEmpty()) {
8685
configMap.put("extensions", environment.extensions().asList());

tools/ai/BUILD.bazel

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,21 @@ package(
55
default_visibility = ["//visibility:public"],
66
)
77

8+
java_library(
9+
name = "agentic_policy_environment",
10+
exports = ["//tools/src/main/java/dev/cel/tools/ai:agentic_policy_environment"],
11+
)
12+
813
java_library(
914
name = "agentic_policy_compiler",
1015
exports = ["//tools/src/main/java/dev/cel/tools/ai:agentic_policy_compiler"],
1116
)
1217

18+
alias(
19+
name = "ai_environments",
20+
actual = "//tools/src/main/resources/environment:ai_environments",
21+
)
22+
1323
alias(
1424
name = "test_policies",
1525
testonly = True,

tools/src/main/java/dev/cel/tools/ai/AgenticPolicyCompiler.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static dev.cel.common.formats.YamlHelper.assertYamlType;
44

5+
import com.google.protobuf.Descriptors.FileDescriptor;
56
import dev.cel.bundle.Cel;
67
import dev.cel.common.CelAbstractSyntaxTree;
78
import dev.cel.common.formats.ValueString;
@@ -66,7 +67,9 @@ public void visitPolicyTag(
6667
break;
6768

6869
case "variables":
69-
if (!assertYamlType(ctx, id, node, YamlNodeType.LIST)) return;
70+
if (!assertYamlType(ctx, id, node, YamlNodeType.LIST)) {
71+
return;
72+
}
7073
List<Variable> parsedVariables = new ArrayList<>();
7174
SequenceNode varList = (SequenceNode) node;
7275

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package dev.cel.tools.ai;
2+
3+
import static java.nio.charset.StandardCharsets.UTF_8;
4+
5+
import com.google.common.base.Ascii;
6+
import com.google.common.collect.ImmutableCollection;
7+
import com.google.common.collect.ImmutableList;
8+
import com.google.common.collect.ImmutableSet;
9+
import com.google.common.io.Resources;
10+
import com.google.protobuf.Descriptors.FileDescriptor;
11+
import dev.cel.bundle.Cel;
12+
import dev.cel.bundle.CelEnvironment;
13+
import dev.cel.bundle.CelEnvironmentException;
14+
import dev.cel.bundle.CelEnvironmentYamlParser;
15+
import dev.cel.bundle.CelFactory;
16+
import dev.cel.common.CelContainer;
17+
import dev.cel.common.CelOptions;
18+
import dev.cel.common.types.CelType;
19+
import dev.cel.common.types.CelTypeProvider;
20+
import dev.cel.common.types.OpaqueType;
21+
import dev.cel.expr.ai.Agent;
22+
import dev.cel.expr.ai.AgentMessage;
23+
import dev.cel.expr.ai.AgentMessage.Part;
24+
import dev.cel.expr.ai.ClassificationLabel;
25+
import dev.cel.expr.ai.Finding;
26+
import dev.cel.parser.CelStandardMacro;
27+
import dev.cel.runtime.CelFunctionBinding;
28+
import java.io.IOException;
29+
import java.net.URL;
30+
import java.util.ArrayList;
31+
import java.util.List;
32+
import java.util.Optional;
33+
34+
final class AgenticPolicyEnvironment {
35+
36+
private static final CelOptions CEL_OPTIONS =
37+
CelOptions.current()
38+
.enableTimestampEpoch(true)
39+
.populateMacroCalls(true)
40+
.build();
41+
42+
private static final Cel CEL_BASE_ENV =
43+
CelFactory.standardCelBuilder()
44+
.setContainer(CelContainer.ofName("cel.expr.ai")) // TODO: config?
45+
.addFileTypes(Agent.getDescriptor().getFile())
46+
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
47+
.setTypeProvider(new AgentTypeProvider())
48+
.addFunctionBindings(
49+
CelFunctionBinding.from(
50+
"AgentMessage_threatFindings",
51+
ImmutableList.of(AgentMessage.class),
52+
(args) -> getFindings((AgentMessage) args[0], "threats", ClassificationLabel.Category.THREAT)
53+
),
54+
CelFunctionBinding.from(
55+
"ai.finding_string_double",
56+
ImmutableList.of(String.class, Double.class),
57+
(args) -> Finding.newBuilder()
58+
.setValue((String) args[0])
59+
.setConfidence((Double) args[1])
60+
.build()
61+
),
62+
CelFunctionBinding.from(
63+
"optional_type(list(Finding))_hasAll_list(Finding)",
64+
ImmutableList.of(Optional.class, List.class),
65+
(args) -> hasAllFindings((Optional<List<Finding>>) args[0], (List<Finding>) args[1])
66+
)
67+
)
68+
.setOptions(CEL_OPTIONS)
69+
.build();
70+
71+
private static Optional<List<Finding>> getFindings(AgentMessage msg, String labelName, ClassificationLabel.Category category) {
72+
List<Finding> results = new ArrayList<>();
73+
74+
for (Part part : msg.getPartsList()) {
75+
if (part.hasPrompt()) {
76+
// TODO: Collect from classification
77+
results.add(Finding.newBuilder().setValue("prompt_injection").setConfidence(1.0d).build());
78+
} else if (part.hasToolCall()) {
79+
// TODO: Collect from classification
80+
}
81+
82+
}
83+
84+
if (results.isEmpty()) {
85+
return Optional.empty();
86+
}
87+
88+
return Optional.of(results);
89+
}
90+
91+
private static boolean hasAllFindings(Optional<List<Finding>> sourceOpt, List<Finding> required) {
92+
if (!sourceOpt.isPresent()) {
93+
return false;
94+
}
95+
List<Finding> source = sourceOpt.get();
96+
97+
return required.stream().allMatch(req ->
98+
source.stream().anyMatch(act ->
99+
act.getValue().equals(req.getValue()) &&
100+
act.getConfidence() >= req.getConfidence()
101+
)
102+
);
103+
}
104+
105+
static Cel newInstance() {
106+
Cel celEnv = CEL_BASE_ENV;
107+
108+
celEnv = extendFromConfig(celEnv, "environment/agent_env.yaml");
109+
celEnv = extendFromConfig(celEnv, "environment/common_env.yaml");
110+
return extendFromConfig(celEnv, "environment/tool_call_env.yaml");
111+
}
112+
113+
private static Cel extendFromConfig(Cel cel, String yamlConfigPath) {
114+
String yamlEnv;
115+
try {
116+
yamlEnv = readFile(yamlConfigPath);
117+
} catch (IOException e) {
118+
String errorMsg = String.format("Failed to read %s: %s", yamlConfigPath, e.getMessage());
119+
throw new IllegalArgumentException(errorMsg, e);
120+
}
121+
try {
122+
CelEnvironment env = CelEnvironmentYamlParser.newInstance().parse(yamlEnv);
123+
return env.extend(cel, CEL_OPTIONS);
124+
} catch (CelEnvironmentException e) {
125+
String errorMsg = String.format("Failed to extend CEL environment from %s: %s", yamlConfigPath, e.getMessage());
126+
throw new IllegalArgumentException(errorMsg, e);
127+
}
128+
}
129+
130+
private static String readFile(String path) throws IOException {
131+
URL url = Resources.getResource(Ascii.toLowerCase(path));
132+
return Resources.toString(url, UTF_8);
133+
}
134+
135+
private static final class AgentTypeProvider implements CelTypeProvider {
136+
private static final OpaqueType AGENT_MESSAGE_SET_TYPE = OpaqueType.create("cel.expr.ai.AgentMessageSet");
137+
138+
private static final ImmutableSet<CelType> ALL_TYPES = ImmutableSet.of(AGENT_MESSAGE_SET_TYPE);
139+
140+
@Override
141+
public ImmutableCollection<CelType> types() {
142+
return ALL_TYPES;
143+
}
144+
@Override
145+
public Optional<CelType> findType(String typeName) {
146+
if (typeName.equals(AGENT_MESSAGE_SET_TYPE.name())) {
147+
return Optional.of(AGENT_MESSAGE_SET_TYPE);
148+
}
149+
150+
return Optional.empty();
151+
}
152+
}
153+
154+
private AgenticPolicyEnvironment() {}
155+
}

tools/src/main/java/dev/cel/tools/ai/BUILD.bazel

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,17 @@ package(
66
"//:license",
77
],
88
default_visibility = ["//visibility:public"],
9-
# default_visibility = [
10-
# "//tools/ai:__pkg__",
11-
# ],
9+
# default_visibility = [
10+
# "//tools/ai:__pkg__",
11+
# ],
1212
)
1313

1414
java_library(
1515
name = "agentic_policy_compiler",
1616
srcs = ["AgenticPolicyCompiler.java"],
1717
deps = [
1818
":agent_context_java_proto",
19+
":agentic_policy_environment",
1920
"//bundle:cel",
2021
"//common:cel_ast",
2122
"//common/formats:value_string",
@@ -33,6 +34,28 @@ java_library(
3334
],
3435
)
3536

37+
java_library(
38+
name = "agentic_policy_environment",
39+
srcs = ["AgenticPolicyEnvironment.java"],
40+
resources = ["//tools/ai:ai_environments"],
41+
deps = [
42+
":agent_context_extensions_java_proto",
43+
":agent_context_java_proto",
44+
"//bundle:cel",
45+
"//bundle:environment",
46+
"//bundle:environment_exception",
47+
"//bundle:environment_yaml_parser",
48+
"//common:container",
49+
"//common:options",
50+
"//common/types",
51+
"//common/types:type_providers",
52+
"//parser:macro",
53+
"//runtime:function_binding",
54+
"@maven//:com_google_guava_guava",
55+
"@maven//:com_google_protobuf_protobuf_java",
56+
],
57+
)
58+
3659
proto_library(
3760
name = "agent_context_proto",
3861
srcs = ["agent_context.proto"],

0 commit comments

Comments
 (0)