Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,51 @@ public void write(
final AIGuard.Message value, final Writable writable, final EncodingCache encodingCache) {
final int[] size = {0};
final boolean hasRole = isNotBlank(value.getRole(), size);
final boolean hasContent = isNotBlank(value.getContent(), size);
final boolean hasToolCallId = isNotBlank(value.getToolCallId(), size);
final boolean hasToolCalls = isNotEmpty(value.getToolCalls(), size);

final boolean hasContentParts = isNotEmpty(value.getContentParts(), size);
final boolean hasContent = !hasContentParts && isNotBlank(value.getContent(), size);

writable.startMap(size[0]);
writeString(hasRole, "role", value.getRole(), writable, encodingCache);
writeString(hasContent, "content", value.getContent(), writable, encodingCache);

if (hasContentParts) {
writeContentParts("content", value.getContentParts(), writable, encodingCache);
} else {
writeString(hasContent, "content", value.getContent(), writable, encodingCache);
}

writeString(hasToolCallId, "tool_call_id", value.getToolCallId(), writable, encodingCache);
writeToolCallArray(hasToolCalls, "tool_calls", value.getToolCalls(), writable, encodingCache);
}

private static void writeContentParts(
final String key,
final List<AIGuard.ContentPart> contentParts,
final Writable writable,
final EncodingCache encodingCache) {
writable.writeString(key, encodingCache);
writable.startArray(contentParts.size());

for (final AIGuard.ContentPart part : contentParts) {
writable.startMap(2);

writable.writeString("type", encodingCache);
writable.writeString(part.getType().toString(), encodingCache);

if (part.getType() == AIGuard.ContentPart.Type.TEXT) {
writable.writeString("text", encodingCache);
writable.writeString(part.getText(), encodingCache);
} else if (part.getType() == AIGuard.ContentPart.Type.IMAGE_URL) {
writable.writeString("image_url", encodingCache);
writable.startMap(1);
writable.writeString("url", encodingCache);
writable.writeString(part.getImageUrl().getUrl(), encodingCache);
}
}
}

private static void writeString(
final boolean present,
final String key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,127 @@ class MessageWriterTest extends DDSpecification {
private static String asString(final Value value) {
return value.asStringValue().asString()
}

void 'test write message with text content parts'() {
given:
final message = AIGuard.Message.message('user', [
AIGuard.ContentPart.text('Hello world')
])

when:
writer.writeObject(message, encodingCache)

then:
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
final value = asStringKeyMap(unpacker.unpackValue())
value.size() == 2
asString(value.role) == 'user'

final contentParts = value.content.asArrayValue().list()
contentParts.size() == 1

final part = asStringKeyMap(contentParts[0])
asString(part.type) == 'text'
asString(part.text) == 'Hello world'
}
}

void 'test write message with image_url content parts'() {
given:
final message = AIGuard.Message.message('user', [
AIGuard.ContentPart.imageUrl('https://example.com/image.jpg')
])

when:
writer.writeObject(message, encodingCache)

then:
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
final value = asStringKeyMap(unpacker.unpackValue())
value.size() == 2
asString(value.role) == 'user'

final contentParts = value.content.asArrayValue().list()
contentParts.size() == 1

final part = asStringKeyMap(contentParts[0])
asString(part.type) == 'image_url'

final imageUrl = asStringKeyMap(part.image_url)
asString(imageUrl.url) == 'https://example.com/image.jpg'
}
}

void 'test write message with mixed content parts'() {
given:
final message = AIGuard.Message.message('user', [
AIGuard.ContentPart.text('Describe this:'),
AIGuard.ContentPart.imageUrl('https://example.com/image.jpg'),
AIGuard.ContentPart.text('What is it?')
])

when:
writer.writeObject(message, encodingCache)

then:
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
final value = asStringKeyMap(unpacker.unpackValue())
value.size() == 2
asString(value.role) == 'user'

final contentParts = value.content.asArrayValue().list()
contentParts.size() == 3

final part1 = asStringKeyMap(contentParts[0])
asString(part1.type) == 'text'
asString(part1.text) == 'Describe this:'

final part2 = asStringKeyMap(contentParts[1])
asString(part2.type) == 'image_url'
final imageUrl = asStringKeyMap(part2.image_url)
asString(imageUrl.url) == 'https://example.com/image.jpg'

final part3 = asStringKeyMap(contentParts[2])
asString(part3.type) == 'text'
asString(part3.text) == 'What is it?'
}
}

void 'test content parts type serializes as string not integer'() {
given:
final message = AIGuard.Message.message('user', [
AIGuard.ContentPart.text('Test')
])

when:
writer.writeObject(message, encodingCache)

then:
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
final value = asStringKeyMap(unpacker.unpackValue())
final contentParts = value.content.asArrayValue().list()
final part = asStringKeyMap(contentParts[0])

// Verify type is a string value, not an integer
part.type.isStringValue()
!part.type.isIntegerValue()
asString(part.type) == 'text'
}
}

void 'test backward compatibility with string content'() {
given:
final message = AIGuard.Message.message('user', 'Plain text message')

when:
writer.writeObject(message, encodingCache)

then:
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
final value = asStringValueMap(unpacker.unpackValue())
value.size() == 2
value.role == 'user'
value.content == 'Plain text message'
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import datadog.trace.api.aiguard.AIGuard.AIGuardAbortError;
import datadog.trace.api.aiguard.AIGuard.AIGuardClientError;
import datadog.trace.api.aiguard.AIGuard.Action;
import datadog.trace.api.aiguard.AIGuard.ContentPart;
import datadog.trace.api.aiguard.AIGuard.Evaluation;
import datadog.trace.api.aiguard.AIGuard.Message;
import datadog.trace.api.aiguard.AIGuard.Options;
Expand Down Expand Up @@ -136,16 +137,44 @@ private static List<Message> messagesForMetaStruct(List<Message> messages) {
boolean contentTruncated = false;
for (int i = messages.size() - size; i < messages.size(); i++) {
final Message source = messages.get(i);
String content = source.getContent();
if (content != null && content.length() > maxContent) {
contentTruncated = true;
content = content.substring(0, maxContent);
}
List<ToolCall> toolCalls = source.getToolCalls();
if (toolCalls != null) {
toolCalls = new ArrayList<>(toolCalls);

List<ContentPart> contentParts = source.getContentParts();
if (contentParts != null) {
final List<ContentPart> truncatedParts = new ArrayList<>(contentParts.size());
for (final ContentPart part : contentParts) {
if (part.getType() == ContentPart.Type.TEXT) {
String text = part.getText();
if (text != null && text.length() > maxContent) {
contentTruncated = true;
text = text.substring(0, maxContent);
truncatedParts.add(ContentPart.text(text));
} else {
truncatedParts.add(part);
}
} else {
truncatedParts.add(part);
}
}

List<ToolCall> toolCalls = source.getToolCalls();
if (toolCalls != null) {
toolCalls = new ArrayList<>(toolCalls);
}
result.add(
new Message(source.getRole(), truncatedParts, toolCalls, source.getToolCallId()));
} else {
// Handle plain text content (backward compatibility)
String content = source.getContent();
if (content != null && content.length() > maxContent) {
contentTruncated = true;
content = content.substring(0, maxContent);
}
List<ToolCall> toolCalls = source.getToolCalls();
if (toolCalls != null) {
toolCalls = new ArrayList<>(toolCalls);
}
result.add(new Message(source.getRole(), content, toolCalls, source.getToolCallId()));
}
result.add(new Message(source.getRole(), content, toolCalls, source.getToolCallId()));
}
if (contentTruncated) {
WafMetricCollector.get().aiGuardTruncated(CONTENT);
Expand Down Expand Up @@ -333,12 +362,45 @@ public Message fromJson(JsonReader reader) throws IOException {
public void toJson(final JsonWriter writer, final Message value) throws IOException {
writer.beginObject();
writeValue(writer, "role", value.getRole());
writeValue(writer, "content", value.getContent());

if (value.getContentParts() != null) {
writeContentParts(writer, "content", value.getContentParts());
} else {
writeValue(writer, "content", value.getContent());
}

writeArray(writer, "tool_calls", value.getToolCalls());
writeValue(writer, "tool_call_id", value.getToolCallId());
writer.endObject();
}

private void writeContentParts(
final JsonWriter writer, final String name, final List<ContentPart> contentParts)
throws IOException {
writer.name(name);
writer.beginArray();
for (final ContentPart part : contentParts) {
writer.beginObject();

writer.name("type");
writer.value(part.getType().toString());

if (part.getType() == ContentPart.Type.TEXT) {
writer.name("text");
writer.value(part.getText());
} else if (part.getType() == ContentPart.Type.IMAGE_URL) {
writer.name("image_url");
writer.beginObject();
writer.name("url");
writer.value(part.getImageUrl().getUrl());
writer.endObject();
}

writer.endObject();
}
writer.endArray();
}

private void writeValue(final JsonWriter writer, final String name, final Object value)
throws IOException {
if (value != null) {
Expand Down
Loading
Loading