Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions src/main/java/com/google/genai/BraintrustApiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ private void tagSpan(
@Nullable String genAIEndpoint,
@Nullable String requestMethod,
@Nullable String requestBody,
@Nullable String responseBody) {
@Nullable String responseBody,
double timeToFirstToken) {
try {
Map<String, Object> metadata = new java.util.HashMap<>();
metadata.put("provider", "gemini");
Expand Down Expand Up @@ -126,9 +127,13 @@ private void tagSpan(
"braintrust.output_json", JSON_MAPPER.writeValueAsString(responseJson));

// Parse usage metadata for metrics
Map<String, Number> metrics = new java.util.HashMap<>();

// Always add time_to_first_token
metrics.put("time_to_first_token", timeToFirstToken);

if (responseJson.get("usageMetadata") instanceof Map) {
var usage = (Map<String, Object>) responseJson.get("usageMetadata");
Map<String, Number> metrics = new java.util.HashMap<>();

if (usage.containsKey("promptTokenCount")) {
metrics.put("prompt_tokens", (Number) usage.get("promptTokenCount"));
Expand All @@ -145,10 +150,10 @@ private void tagSpan(
"prompt_cached_tokens",
(Number) usage.get("cachedContentTokenCount"));
}

span.setAttribute(
"braintrust.metrics", JSON_MAPPER.writeValueAsString(metrics));
}

// Always set metrics (at minimum with time_to_first_token)
span.setAttribute("braintrust.metrics", JSON_MAPPER.writeValueAsString(metrics));
}

// Set metadata
Expand Down Expand Up @@ -195,10 +200,19 @@ public ApiResponse request(
Span span =
tracer.spanBuilder(getOperation(genAIUrl)).setSpanKind(SpanKind.CLIENT).startSpan();
try (Scope scope = span.makeCurrent()) {
long startTimeNanos = System.nanoTime();
ApiResponse response = delegate.request(requestMethod, genAIUrl, requestBody, options);
double timeToFirstToken = (System.nanoTime() - startTimeNanos) / 1_000_000_000.0;

BufferedApiResponse bufferedResponse = new BufferedApiResponse(response);
span.setStatus(StatusCode.OK);
tagSpan(span, genAIUrl, requestMethod, requestBody, bufferedResponse.getBodyAsString());
tagSpan(
span,
genAIUrl,
requestMethod,
requestBody,
bufferedResponse.getBodyAsString(),
timeToFirstToken);
return bufferedResponse;
} catch (Throwable t) {
span.setStatus(StatusCode.ERROR, t.getMessage());
Expand All @@ -219,16 +233,20 @@ public ApiResponse request(
Span span =
tracer.spanBuilder(getOperation(genAIUrl)).setSpanKind(SpanKind.CLIENT).startSpan();
try (Scope scope = span.makeCurrent()) {
long startTimeNanos = System.nanoTime();
ApiResponse response =
delegate.request(requestMethod, genAIUrl, requestBodyBytes, options);
double timeToFirstToken = (System.nanoTime() - startTimeNanos) / 1_000_000_000.0;

BufferedApiResponse bufferedResponse = new BufferedApiResponse(response);
span.setStatus(StatusCode.OK);
tagSpan(
span,
genAIUrl,
requestMethod,
new String(requestBodyBytes),
bufferedResponse.getBodyAsString());
bufferedResponse.getBodyAsString(),
timeToFirstToken);
return bufferedResponse;
} catch (Throwable t) {
span.setStatus(StatusCode.ERROR, t.getMessage());
Expand All @@ -244,6 +262,7 @@ public CompletableFuture<ApiResponse> asyncRequest(
String method, String url, String body, Optional<HttpOptions> options) {
Span span = tracer.spanBuilder(getOperation(url)).setSpanKind(SpanKind.CLIENT).startSpan();
Context context = Context.current().with(span);
long startTimeNanos = System.nanoTime();

return delegate.asyncRequest(method, url, body, options)
.handle(
Expand All @@ -256,6 +275,9 @@ public CompletableFuture<ApiResponse> asyncRequest(
}

try {
double timeToFirstToken =
(System.nanoTime() - startTimeNanos) / 1_000_000_000.0;

// Buffer the response so we can read it for instrumentation
BufferedApiResponse bufferedResponse =
new BufferedApiResponse(response);
Expand All @@ -265,7 +287,8 @@ public CompletableFuture<ApiResponse> asyncRequest(
url,
method,
body,
bufferedResponse.getBodyAsString());
bufferedResponse.getBodyAsString(),
timeToFirstToken);
return (ApiResponse) bufferedResponse;
} catch (Exception e) {
span.setStatus(StatusCode.ERROR, e.getMessage());
Expand All @@ -283,6 +306,7 @@ public CompletableFuture<ApiResponse> asyncRequest(
String method, String url, byte[] body, Optional<HttpOptions> options) {
Span span = tracer.spanBuilder(getOperation(url)).setSpanKind(SpanKind.CLIENT).startSpan();
Context context = Context.current().with(span);
long startTimeNanos = System.nanoTime();

return delegate.asyncRequest(method, url, body, options)
.handle(
Expand All @@ -295,6 +319,9 @@ public CompletableFuture<ApiResponse> asyncRequest(
}

try {
double timeToFirstToken =
(System.nanoTime() - startTimeNanos) / 1_000_000_000.0;

// Buffer the response so we can read it for instrumentation
BufferedApiResponse bufferedResponse =
new BufferedApiResponse(response);
Expand All @@ -304,7 +331,8 @@ public CompletableFuture<ApiResponse> asyncRequest(
url,
method,
new String(body),
bufferedResponse.getBodyAsString());
bufferedResponse.getBodyAsString(),
timeToFirstToken);
return (ApiResponse) bufferedResponse;
} catch (Exception e) {
span.setStatus(StatusCode.ERROR, e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ void testWrapGemini() {
assertTrue(metrics.get("prompt_tokens").asInt() > 0, "prompt_tokens should be > 0");
assertTrue(metrics.get("completion_tokens").asInt() > 0, "completion_tokens should be > 0");
assertTrue(metrics.get("tokens").asInt() > 0, "tokens should be > 0");
assertTrue(metrics.has("time_to_first_token"), "time_to_first_token should be present");
assertTrue(
metrics.get("time_to_first_token").asDouble() >= 0.0,
"time_to_first_token should be >= 0");

// Verify braintrust.span_attributes marks this as an LLM span
String spanAttributesJson =
Expand Down Expand Up @@ -153,6 +157,10 @@ void testWrapGeminiAsync() {
assertTrue(metrics.get("prompt_tokens").asInt() > 0, "prompt_tokens should be > 0");
assertTrue(metrics.get("completion_tokens").asInt() > 0, "completion_tokens should be > 0");
assertTrue(metrics.get("tokens").asInt() > 0, "tokens should be > 0");
assertTrue(metrics.has("time_to_first_token"), "time_to_first_token should be present");
assertTrue(
metrics.get("time_to_first_token").asDouble() >= 0.0,
"time_to_first_token should be >= 0");

// Verify braintrust.span_attributes marks this as an LLM span
String spanAttributesJson =
Expand Down