From 2e439361fc338cf8581b68b26acd9b8036b9a2b9 Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Thu, 30 Apr 2026 12:31:05 -0600 Subject: [PATCH] feature: traced scorers --- .../api/BraintrustOpenApiClient.java | 51 +- .../api/BtqlRateLimitException.java | 12 + .../main/java/dev/braintrust/eval/Eval.java | 29 +- .../dev/braintrust/eval/TracedScorer.java | 41 ++ .../dev/braintrust/trace/BrainstoreTrace.java | 357 ++++++++++++++ .../braintrust/trace/BraintrustTracing.java | 14 + .../braintrust/eval/TracedScorerEvalTest.java | 323 +++++++++++++ .../braintrust/trace/BrainstoreTraceTest.java | 443 ++++++++++++++++++ .../braintrust/sdkspecimpl/SpanFetcher.java | 261 ++--------- .../examples/TraceScoringExample.java | 127 +++++ 10 files changed, 1436 insertions(+), 222 deletions(-) create mode 100644 braintrust-sdk/src/main/java/dev/braintrust/api/BtqlRateLimitException.java create mode 100644 braintrust-sdk/src/main/java/dev/braintrust/eval/TracedScorer.java create mode 100644 braintrust-sdk/src/main/java/dev/braintrust/trace/BrainstoreTrace.java create mode 100644 braintrust-sdk/src/test/java/dev/braintrust/eval/TracedScorerEvalTest.java create mode 100644 braintrust-sdk/src/test/java/dev/braintrust/trace/BrainstoreTraceTest.java create mode 100644 examples/src/main/java/dev/braintrust/examples/TraceScoringExample.java diff --git a/braintrust-sdk/src/main/java/dev/braintrust/api/BraintrustOpenApiClient.java b/braintrust-sdk/src/main/java/dev/braintrust/api/BraintrustOpenApiClient.java index e7cb09be..5c388ab9 100644 --- a/braintrust-sdk/src/main/java/dev/braintrust/api/BraintrustOpenApiClient.java +++ b/braintrust-sdk/src/main/java/dev/braintrust/api/BraintrustOpenApiClient.java @@ -145,12 +145,18 @@ public BtqlQueryResponse btqlQuery(String query) { getHttpClient() .send(requestBuilder.build(), HttpResponse.BodyHandlers.ofString()); - if (response.statusCode() / 100 != 2) { - throw new RuntimeException( - "BTQL query failed with status " - + response.statusCode() - + ": " - + response.body()); + if (response.statusCode() == 429) { + throw new BtqlRateLimitException( + response.statusCode(), + "BTQL rate limit exceeded", + response.headers(), + response.body()); + } else if (response.statusCode() / 100 != 2) { + throw new ApiException( + response.statusCode(), + "BTQL query failed", + response.headers(), + response.body()); } return MAPPER.readValue(response.body(), BtqlQueryResponse.class); @@ -163,7 +169,38 @@ public BtqlQueryResponse btqlQuery(String query) { public record OrgInfo(String id, String name) {} - public record BtqlQueryResponse(List> data) {} + /** + * Response from a {@code POST /btql} query. + * + *

Freshness is determined by comparing {@link FreshnessState#lastProcessedXactId()} to + * {@link FreshnessState#lastConsideredXactId()}: when both are non-null and equal, the query + * has caught up to all ingested data and the result is fresh. + * + *

The {@link RealtimeState#type()} field indicates whether realtime indexing is still active + * ({@code "on"}) or has timed out ({@code "exhausted_timeout"}). + */ + public record BtqlQueryResponse( + List> data, + @JsonProperty("freshness_state") FreshnessState freshnessState, + @JsonProperty("realtime_state") RealtimeState realtimeState) { + + /** Returns {@code true} when the query result has caught up to all ingested data. */ + public boolean isFresh() { + if (freshnessState == null) { + return false; + } + var processed = freshnessState.lastProcessedXactId(); + var considered = freshnessState.lastConsideredXactId(); + return processed != null && processed.equals(considered); + } + } + + public record FreshnessState( + @JsonProperty("last_processed_xact_id") String lastProcessedXactId, + @JsonProperty("last_considered_xact_id") String lastConsideredXactId) {} + + /** Real-time indexing state for a BTQL query. */ + public record RealtimeState(@JsonProperty("type") String type) {} private record LoginRequest(String token) {} diff --git a/braintrust-sdk/src/main/java/dev/braintrust/api/BtqlRateLimitException.java b/braintrust-sdk/src/main/java/dev/braintrust/api/BtqlRateLimitException.java new file mode 100644 index 00000000..9780a7ec --- /dev/null +++ b/braintrust-sdk/src/main/java/dev/braintrust/api/BtqlRateLimitException.java @@ -0,0 +1,12 @@ +package dev.braintrust.api; + +import dev.braintrust.openapi.ApiException; +import java.net.http.HttpHeaders; + +/** Thrown when the BTQL endpoint returns HTTP 429 (Too Many Requests). */ +public final class BtqlRateLimitException extends ApiException { + BtqlRateLimitException( + int code, String message, HttpHeaders responseHeaders, String responseBody) { + super(code, message, responseHeaders, responseBody); + } +} diff --git a/braintrust-sdk/src/main/java/dev/braintrust/eval/Eval.java b/braintrust-sdk/src/main/java/dev/braintrust/eval/Eval.java index 0286d46a..ee814baf 100644 --- a/braintrust-sdk/src/main/java/dev/braintrust/eval/Eval.java +++ b/braintrust-sdk/src/main/java/dev/braintrust/eval/Eval.java @@ -9,6 +9,7 @@ import dev.braintrust.openapi.api.ExperimentsApi; import dev.braintrust.openapi.model.CreateExperiment; import dev.braintrust.openapi.model.Project; +import dev.braintrust.trace.BrainstoreTrace; import dev.braintrust.trace.BraintrustContext; import dev.braintrust.trace.BraintrustTracing; import io.opentelemetry.api.common.AttributeKey; @@ -124,6 +125,7 @@ private void evalOne(String experimentId, DatasetCase datasetCase } try (var rootScope = BraintrustContext.ofExperiment(experimentId, rootSpan).makeCurrent()) { final TaskResult taskResult; + final String taskSpanId; { // run task var taskSpan = tracer.spanBuilder("task") @@ -132,6 +134,7 @@ private void evalOne(String experimentId, DatasetCase datasetCase "braintrust.span_attributes", toJson(Map.of("type", "task"))) .startSpan(); + taskSpanId = taskSpan.getSpanContext().getSpanId(); try (var unused = BraintrustContext.ofExperiment(experimentId, taskSpan).makeCurrent()) { taskResult = task.apply(datasetCase, parameters); @@ -155,9 +158,19 @@ private void evalOne(String experimentId, DatasetCase datasetCase } taskSpan.end(); } + + // Create a single BrainstoreTrace for this eval case, shared across all scorers. + // It fetches spans lazily on first access (only if a TracedScorer actually calls it). + // We wait specifically for the task span to appear, which guarantees its children + // (LLM spans, tool spans) have also been indexed — since children end before parents. + var rootTraceId = rootSpan.getSpanContext().getTraceId(); + var trace = + BrainstoreTrace.forExperiment( + client, experimentId, rootTraceId, List.of(taskSpanId)); + // run scorers - one span per scorer for (var scorer : scorers) { - runScorer(experimentId, rootSpan, scorer, taskResult); + runScorer(experimentId, rootSpan, scorer, taskResult, trace); } } finally { rootSpan.end(); @@ -165,14 +178,16 @@ private void evalOne(String experimentId, DatasetCase datasetCase } /** - * Runs a scorer against a successful task result. If the scorer throws, falls back to {@link - * Scorer#scoreForScorerException}. + * Runs a scorer against a successful task result. If the scorer is a {@link TracedScorer}, it + * receives the {@link BrainstoreTrace} for the eval case. If the scorer throws, falls back to + * {@link Scorer#scoreForScorerException}. */ private void runScorer( String experimentId, Span rootSpan, Scorer scorer, - TaskResult taskResult) { + TaskResult taskResult, + BrainstoreTrace trace) { var scoreSpan = tracer.spanBuilder("score") .setAttribute(PARENT, "experiment_id:" + experimentId) @@ -180,7 +195,11 @@ private void runScorer( try (var unused = BraintrustContext.ofExperiment(experimentId, scoreSpan).makeCurrent()) { List scores; try { - scores = scorer.score(taskResult); + if (scorer instanceof TracedScorer tracedScorer) { + scores = tracedScorer.score(taskResult, trace); + } else { + scores = scorer.score(taskResult); + } } catch (Exception e) { scoreSpan.setStatus(StatusCode.ERROR, e.getMessage()); scoreSpan.recordException(e); diff --git a/braintrust-sdk/src/main/java/dev/braintrust/eval/TracedScorer.java b/braintrust-sdk/src/main/java/dev/braintrust/eval/TracedScorer.java new file mode 100644 index 00000000..352a7130 --- /dev/null +++ b/braintrust-sdk/src/main/java/dev/braintrust/eval/TracedScorer.java @@ -0,0 +1,41 @@ +package dev.braintrust.eval; + +import dev.braintrust.trace.BrainstoreTrace; +import java.util.List; + +/** + * A scorer that receives access to the full distributed trace of the task that was evaluated. + * + *

Implement this interface when your scorer needs to examine intermediate LLM calls, tool + * invocations, or other spans produced during task execution — not just the final {@code + * TaskResult}. + * + * @param type of the input data + * @param type of the output data + */ +public interface TracedScorer extends Scorer { + + /** + * Scores the task result using the distributed trace for additional context. Called instead of + * {@link Scorer#score(TaskResult)} when a {@link BrainstoreTrace} is available. + * + * @param taskResult the task output and originating dataset case + * @param trace lazy access to the distributed trace spans for this eval case + * @return one or more scores, each with a value between 0 and 1 inclusive + */ + List score(TaskResult taskResult, BrainstoreTrace trace); + + /** + * {@inheritDoc} + * + *

When used inside an {@link Eval}, this overload is never called — {@link + * #score(TaskResult, BrainstoreTrace)} is dispatched instead. This default implementation + * throws {@link UnsupportedOperationException} to surface any accidental direct calls. + */ + @Override + default List score(TaskResult taskResult) { + throw new UnsupportedOperationException( + "traced scorer score method directly called. This is likely an accident. If you" + + " wish to support this, your implementation must override this method."); + } +} diff --git a/braintrust-sdk/src/main/java/dev/braintrust/trace/BrainstoreTrace.java b/braintrust-sdk/src/main/java/dev/braintrust/trace/BrainstoreTrace.java new file mode 100644 index 00000000..361980f2 --- /dev/null +++ b/braintrust-sdk/src/main/java/dev/braintrust/trace/BrainstoreTrace.java @@ -0,0 +1,357 @@ +package dev.braintrust.trace; + +import dev.braintrust.api.BraintrustOpenApiClient; +import dev.braintrust.api.BtqlRateLimitException; +import dev.braintrust.json.BraintrustJsonMapper; +import io.opentelemetry.api.GlobalOpenTelemetry; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; + +/** + * Provides access to the distributed trace spans for a single eval task stored in Braintrust. + * + *

Spans are fetched lazily on first access and cached for subsequent calls. Score-type spans are + * excluded from all results (they are filtered out by the BTQL query). + */ +@Slf4j +public class BrainstoreTrace { + /** Maximum number of attempts (exponential backoff: 1s→2s→4s→…→30s). */ + private static final int MAX_ATTEMPTS = 8; + + /** Fixed sleep when the server returns a 429 rate-limit; the attempt is not counted. */ + private static final int MAX_RATE_LIMIT_ATTEMPTS = 30; + + private static final long RATE_LIMIT_SLEEP_MS = 30_000; + + private static final int BASE_DELAY_MS = 5_000; + private static final int MAX_DELAY_MS = 30_000; + + private final Supplier>> spansSupplier; + private final ReentrantLock lock = new ReentrantLock(); + + @Nullable private volatile List> cachedSpans; + + /** + * Creates a {@code BrainstoreTrace} that fetches spans for a trace stored in a Braintrust + * experiment. + * + *

Queries {@code experiment('')} and excludes score-type spans. The retry loop + * blocks until all {@code expectedSpanIds} are present in the results. + * + * @param client the API client used to execute BTQL queries + * @param experimentId the experiment whose spans to query + * @param rootTraceId the OTel trace ID (hex string, 32 chars) + * @param expectedSpanIds OTel span IDs (16 hex chars each) that must all appear in the results + * before the fetch is considered complete + */ + public static BrainstoreTrace forExperiment( + @Nonnull BraintrustOpenApiClient client, + @Nonnull String experimentId, + @Nonnull String rootTraceId, + @Nonnull List expectedSpanIds) { + var safeExperimentId = experimentId.replace("'", "''"); + var safeRootTraceId = rootTraceId.replace("'", "''"); + var query = + "SELECT * FROM experiment('%s') WHERE root_span_id = '%s' AND span_attributes.type != 'score' ORDER BY created ASC LIMIT 1000" + .formatted(safeExperimentId, safeRootTraceId); + return new BrainstoreTrace(() -> fetchWithRetry(client, query, expectedSpanIds)); + } + + /** + * Creates a {@code BrainstoreTrace} that fetches spans for a trace stored in Braintrust project + * logs. + * + *

Queries {@code project_logs('')} for all spans in the trace. The retry loop + * blocks until all {@code expectedSpanIds} are present in the results. + * + * @param client the API client used to execute BTQL queries + * @param projectId the project whose logs to query + * @param rootTraceId the OTel trace ID (hex string, 32 chars) + * @param expectedSpanIds OTel span IDs (16 hex chars each) that must all appear in the results + * before the fetch is considered complete + */ + public static BrainstoreTrace forTrace( + @Nonnull BraintrustOpenApiClient client, + @Nonnull String projectId, + @Nonnull String rootTraceId, + @Nonnull List expectedSpanIds) { + var safeProjectId = projectId.replace("'", "''"); + var safeRootTraceId = rootTraceId.replace("'", "''"); + var query = + "SELECT * FROM project_logs('%s') WHERE root_span_id = '%s' ORDER BY created ASC LIMIT 1000" + .formatted(safeProjectId, safeRootTraceId); + return new BrainstoreTrace(() -> fetchWithRetry(client, query, expectedSpanIds)); + } + + /** + * Creates a {@code BrainstoreTrace} backed by a custom supplier. Primarily useful for testing. + */ + BrainstoreTrace(@Nonnull Supplier>> spansSupplier) { + this.spansSupplier = spansSupplier; + } + + /** + * Returns all non-score spans for this trace. Results are fetched on first call and cached. + * + * @return an immutable list of span maps; each map contains the span's fields as returned by + * BTQL (e.g. {@code "input"}, {@code "output"}, {@code "span_attributes"}, {@code + * "start_time"}, {@code "end_time"}) + */ + public List> getSpans() { + var cached = cachedSpans; + if (cached != null) { + return cached; + } + lock.lock(); + try { + if (cachedSpans == null) { + cachedSpans = List.copyOf(spansSupplier.get()); + } + return cachedSpans; + } finally { + lock.unlock(); + } + } + + /** + * Returns spans filtered by {@code span_attributes.type}. + * + *

Common types: {@code "llm"}, {@code "task"}, {@code "eval"}, {@code "tool"}, {@code + * "function"}. + * + * @param spanType the value of {@code span_attributes.type} to filter by + * @return spans whose {@code span_attributes.type} matches {@code spanType} + */ + public List> getSpans(@Nonnull String spanType) { + return getSpans().stream().filter(span -> spanType.equals(getSpanType(span))).toList(); + } + + /** + * Reconstructs the LLM conversation thread from all LLM spans in this trace. + * + *

Flattens the span tree via pre-order DFS (parent before children, siblings in {@code + * metrics.start} order), then walks the resulting LLM-span sequence and de-duplicates using a + * seen-set: any input or output item already added to the thread is skipped. + * + * @return a flat, ordered list of message/output maps from all LLM spans in the trace + */ + public List> getLLMConversationThread() { + var allSpans = getSpans(); + + // Build children map: parent_id → children sorted by start_time + var children = new java.util.LinkedHashMap>>(); + for (var span : allSpans) { + var parents = span.get("span_parents"); + if (parents instanceof List parentList && !parentList.isEmpty()) { + if (parentList.get(0) instanceof String pid) { + children.computeIfAbsent(pid, k -> new ArrayList<>()).add(span); + } + } + } + children.values() + .forEach( + list -> + list.sort( + (a, b) -> + Double.compare(getStartTime(a), getStartTime(b)))); + + // Find root span (no parents) + var root = + allSpans.stream() + .filter( + s -> { + var p = s.get("span_parents"); + return p == null || (p instanceof List l && l.isEmpty()); + }) + .findFirst() + .orElse(null); + if (root == null) return List.of(); + + // Pre-order DFS to get all LLM spans in hierarchy order. + // Prune entire subtrees rooted at scorer spans (purpose == "scorer") — these are + // synthetic spans injected by the Braintrust backend and not part of the real trace. + var llmSpansInOrder = new ArrayList>(); + var stack = new java.util.ArrayDeque>(); + stack.push(root); + while (!stack.isEmpty()) { + var span = stack.pop(); + if ("automation".equals(getSpanType(span))) { + // prune topics and other synthetic spans + continue; + } + if ("llm".equals(getSpanType(span))) { + llmSpansInOrder.add(span); + } + // Push children in reverse order so first child is processed first + var spanId = span.get("span_id"); + if (spanId instanceof String sid) { + var childList = children.getOrDefault(sid, List.of()); + for (int i = childList.size() - 1; i >= 0; i--) { + stack.push(childList.get(i)); + } + } + } + + // Walk LLM spans in order, adding unseen input/output items + var thread = new ArrayList>(); + var seen = new java.util.LinkedHashSet(); + for (var span : llmSpansInOrder) { + for (var msg : getInputMessages(span)) { + if (seen.add(msg)) { + thread.add(msg); + } + } + for (var out : getOutputMessages(span)) { + if (seen.add(out)) { + thread.add(out); + } + } + } + return List.copyOf(thread); + } + + // ------------------------------------------------------------------------- + // Private helpers + // ------------------------------------------------------------------------- + + @Nullable + private static String getSpanType(Map span) { + var attrs = span.get("span_attributes"); + if (attrs instanceof Map attrsMap) { + var type = attrsMap.get("type"); + return type instanceof String s ? s : null; + } + return null; + } + + private static double getStartTime(Map span) { + var t = span.get("metrics"); + if (t instanceof Map metrics) { + var start = metrics.get("start"); + if (start instanceof Number n) return n.doubleValue(); + } + return Double.MAX_VALUE; + } + + @SuppressWarnings("unchecked") + private static List> getInputMessages(Map span) { + var input = span.get("input"); + // Input may be a raw List or a JSON-encoded string + if (input instanceof List inputList) { + return inputList.isEmpty() ? List.of() : (List>) inputList; + } + if (input instanceof String s && !s.isBlank()) { + try { + var parsed = BraintrustJsonMapper.fromJson(s, List.class); + if (parsed instanceof List l) { + return (List>) l; + } + } catch (Exception e) { + log.debug("could not parse input as JSON array: {}", e.getMessage()); + } + } + return List.of(); + } + + @SuppressWarnings("unchecked") + private static List> getOutputMessages(Map span) { + var output = span.get("output"); + if (output instanceof List outputList) { + return outputList.isEmpty() ? List.of() : (List>) outputList; + } + log.debug("unexpected output type: {}", BraintrustJsonMapper.toJson(output)); + return List.of(); + } + + /** + * Polls BTQL with the given {@code query}, retrying with exponential backoff until all {@code + * expectedSpanIds} appear in the results, or until max attempts are exhausted. + * + *

429 rate-limit responses are handled transparently: the thread sleeps {@link + * #RATE_LIMIT_SLEEP_MS} and the attempt is retried without consuming an attempt slot. + */ + @SneakyThrows + private static List> fetchWithRetry( + BraintrustOpenApiClient client, String query, List expectedSpanIds) { + + BraintrustOpenApiClient.BtqlQueryResponse lastResponse = null; + int delayMs = BASE_DELAY_MS; + + for (int attempt = 0; attempt < MAX_ATTEMPTS; attempt++) { + if (attempt > 0) { + try { + Thread.sleep(delayMs); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + log.warn("BrainstoreTrace: interrupted while waiting for expected spans"); + break; + } + delayMs = Math.min(delayMs * 2, MAX_DELAY_MS); + } + + // Retry rate-limit responses without consuming an attempt slot. + for (int rateLimitAttempt = 0; + rateLimitAttempt < MAX_RATE_LIMIT_ATTEMPTS; + ++rateLimitAttempt) { + try { + lastResponse = client.btqlQuery(query); + break; + } catch (BtqlRateLimitException e) { + if (rateLimitAttempt == MAX_RATE_LIMIT_ATTEMPTS - 1) { + log.error( + "Failed to fetch spans from Braintrust. Max attempts exceeded." + + " Giving up."); + throw e; + } + log.debug( + "BrainstoreTrace: rate limited, sleeping {}ms then retrying: {}", + RATE_LIMIT_SLEEP_MS, + e.getMessage()); + Thread.sleep(RATE_LIMIT_SLEEP_MS); + } + } + if (lastResponse == null) break; + + var presentSpanIds = + lastResponse.data().stream() + .map(row -> row.get("span_id")) + .filter(id -> id instanceof String) + .map(id -> (String) id) + .collect(Collectors.toSet()); + var missingSpanIds = + expectedSpanIds.stream().filter(id -> !presentSpanIds.contains(id)).toList(); + + log.debug( + "BrainstoreTrace BTQL attempt {}/{}: rows={}, missing={}/{}", + attempt + 1, + MAX_ATTEMPTS, + lastResponse.data().size(), + missingSpanIds.size(), + expectedSpanIds.size()); + + if (missingSpanIds.isEmpty()) { + break; + } else if (attempt == 0) { + // OPTIMIZATION: force flush otel to get data into braintrust faster + BraintrustTracing.attemptForceFlush(GlobalOpenTelemetry.get()); + } + + if (attempt >= (MAX_ATTEMPTS - 1)) { + throw new RuntimeException( + ("BrainstoreTrace: max attempts reached waiting for expected spans. " + + "missing span IDs: %s") + .formatted(missingSpanIds)); + } + } + + return lastResponse == null ? List.of() : List.copyOf(lastResponse.data()); + } +} diff --git a/braintrust-sdk/src/main/java/dev/braintrust/trace/BraintrustTracing.java b/braintrust-sdk/src/main/java/dev/braintrust/trace/BraintrustTracing.java index 89380463..f38bd3da 100644 --- a/braintrust-sdk/src/main/java/dev/braintrust/trace/BraintrustTracing.java +++ b/braintrust-sdk/src/main/java/dev/braintrust/trace/BraintrustTracing.java @@ -162,6 +162,20 @@ public static Tracer getTracer(TracerProvider tracerProvider) { return tracerProvider.get(INSTRUMENTATION_NAME, INSTRUMENTATION_VERSION); } + /** + * Attempt to flush the all data from the given open telemetry instance. Does not block + * + * @param openTelemetry the {@link OpenTelemetry} instance whose spans to attempt flushing + */ + static void attemptForceFlush(@Nonnull OpenTelemetry openTelemetry) { + if (openTelemetry instanceof OpenTelemetrySdk sdk) { + synchronized (BraintrustTracing.class) { + sdk.getSdkTracerProvider().forceFlush(); + log.debug("sdk tracer forceFlush initiated"); + } + } + } + private static String sdkInfoLogMessage() { return "Initializing Braintrust OpenTelemetry with service=%s, instrumentation-name=%s, instrumentation-version=%s, jvm-version=%s, jvm-vendor=%s, jvm-name=%s" .formatted( diff --git a/braintrust-sdk/src/test/java/dev/braintrust/eval/TracedScorerEvalTest.java b/braintrust-sdk/src/test/java/dev/braintrust/eval/TracedScorerEvalTest.java new file mode 100644 index 00000000..6e5b9c57 --- /dev/null +++ b/braintrust-sdk/src/test/java/dev/braintrust/eval/TracedScorerEvalTest.java @@ -0,0 +1,323 @@ +package dev.braintrust.eval; + +import static org.junit.jupiter.api.Assertions.*; + +import dev.braintrust.TestHarness; +import dev.braintrust.trace.BrainstoreTrace; +import io.opentelemetry.api.trace.SpanId; +import io.opentelemetry.sdk.trace.data.SpanData; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import lombok.SneakyThrows; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TracedScorerEvalTest { + private TestHarness testHarness; + + @BeforeEach + void beforeEach() { + testHarness = TestHarness.setup(); + } + + /** + * Verifies that when a {@link TracedScorer} is in the scorer list, {@code Eval} dispatches to + * {@link TracedScorer#score(TaskResult, BrainstoreTrace)} instead of {@link + * Scorer#score(TaskResult)}. + * + *

The test scorer captures the {@link BrainstoreTrace} it receives but does NOT call {@code + * getSpans()} — this avoids making a live BTQL HTTP call and keeps the test fast and hermetic. + */ + @Test + @SneakyThrows + void evalDispatchesToTracedScorerWithTrace() { + var capturedTrace = new AtomicReference(); + var capturedTaskResult = new AtomicReference>(); + var scoreCallCount = new AtomicReference<>(0); + + var tracedScorer = + new TracedScorer() { + @Override + public String getName() { + return "trace_aware_scorer"; + } + + @Override + public List score( + TaskResult taskResult, BrainstoreTrace trace) { + capturedTrace.set(trace); + capturedTaskResult.set(taskResult); + scoreCallCount.set(scoreCallCount.get() + 1); + return List.of(new Score(getName(), 1.0)); + } + }; + + var eval = + testHarness + .braintrust() + .evalBuilder() + .name("unit-test-eval") + .cases(DatasetCase.of("hello", "world")) + .taskFunction(input -> "world") + .scorers(tracedScorer) + .build(); + + eval.run(); + + // Verify the TracedScorer overload was called (not the plain Scorer overload) + assertEquals(1, scoreCallCount.get(), "traced scorer should have been called once"); + assertNotNull( + capturedTrace.get(), "TracedScorer should receive a non-null BrainstoreTrace"); + assertNotNull(capturedTaskResult.get(), "TracedScorer should receive the TaskResult"); + assertEquals("hello", capturedTaskResult.get().datasetCase().input()); + assertEquals("world", capturedTaskResult.get().result()); + } + + /** + * Verifies that when a {@link TracedScorer} is mixed with a regular {@link Scorer}, both are + * called correctly: the traced scorer receives a {@link BrainstoreTrace}, the regular scorer + * does not. + */ + @Test + @SneakyThrows + void evalMixedScorersMaintainCorrectDispatch() { + var tracedScorerCalled = new AtomicReference<>(false); + var regularScorerCalled = new AtomicReference<>(false); + + var tracedScorer = + new TracedScorer() { + @Override + public String getName() { + return "traced"; + } + + @Override + public List score( + TaskResult taskResult, BrainstoreTrace trace) { + assertNotNull(trace); + tracedScorerCalled.set(true); + return List.of(new Score(getName(), 1.0)); + } + }; + + var regularScorer = + new Scorer() { + @Override + public String getName() { + return "regular"; + } + + @Override + public List score(TaskResult taskResult) { + regularScorerCalled.set(true); + return List.of(new Score(getName(), 0.5)); + } + }; + + var eval = + testHarness + .braintrust() + .evalBuilder() + .name("unit-test-eval") + .cases(DatasetCase.of("input", "expected")) + .taskFunction(input -> "output") + .scorers(tracedScorer, regularScorer) + .build(); + + eval.run(); + + assertTrue(tracedScorerCalled.get(), "TracedScorer should have been called"); + assertTrue(regularScorerCalled.get(), "regular Scorer should have been called"); + } + + /** + * Verifies that a {@link TracedScorer} that throws falls back to {@link + * Scorer#scoreForScorerException}, just like a regular scorer. + */ + @Test + @SneakyThrows + void evalTracedScorerExceptionFallsBackToScoreForScorerException() { + var fallbackCalled = new AtomicReference<>(false); + + var brokenTracedScorer = + new TracedScorer() { + @Override + public String getName() { + return "broken_traced"; + } + + @Override + public List score( + TaskResult taskResult, BrainstoreTrace trace) { + throw new RuntimeException("traced scorer is broken"); + } + + @Override + public List scoreForScorerException( + Exception e, TaskResult taskResult) { + fallbackCalled.set(true); + return List.of(new Score(getName(), 0.0)); + } + }; + + var eval = + testHarness + .braintrust() + .evalBuilder() + .name("unit-test-eval-scorer-error") + .cases(DatasetCase.of("input", "expected")) + .taskFunction(input -> "output") + .scorers(brokenTracedScorer) + .build(); + + // Should not throw — the broken scorer falls back gracefully + var result = eval.run(); + assertNotNull(result.getExperimentUrl()); + assertTrue(fallbackCalled.get(), "scoreForScorerException should have been called"); + + // Verify the score span has ERROR status and the fallback score of 0.0 + var spans = testHarness.awaitExportedSpans(); + var scoreSpans = + spans.stream() + .filter( + s -> { + var attrs = + s.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey + .stringKey( + "braintrust.span_attributes")); + return attrs != null && attrs.contains("\"type\":\"score\""); + }) + .toList(); + assertEquals(1, scoreSpans.size()); + var scoreSpan = scoreSpans.get(0); + assertEquals( + io.opentelemetry.api.trace.StatusCode.ERROR, + scoreSpan.getStatus().getStatusCode(), + "broken traced scorer span should be ERROR"); + assertTrue( + scoreSpan.getStatus().getDescription().contains("traced scorer is broken"), + "error description should include exception message"); + } + + /** + * Verifies that a regular {@link Scorer} continues to work correctly when mixed with a {@link + * TracedScorer} — specifically that the regular scorer's plain {@code score(TaskResult)} is + * called (not the trace overload). + */ + @Test + @SneakyThrows + void regularScorerUnaffectedByTracedScorerPresence() { + // This test ensures backward compatibility: existing scorers work fine. + var regularScoreCalled = new AtomicReference<>(false); + + Scorer regularScorer = + Scorer.of( + "exact_match", + (String expected, String result) -> { + regularScoreCalled.set(true); + return expected.equals(result) ? 1.0 : 0.0; + }); + + var tracedScorer = + new TracedScorer() { + @Override + public String getName() { + return "noop_traced"; + } + + @Override + public List score( + TaskResult taskResult, BrainstoreTrace trace) { + return List.of(new Score(getName(), 1.0)); + } + }; + + var eval = + testHarness + .braintrust() + .evalBuilder() + .name("unit-test-eval") + .cases(DatasetCase.of("fruit", "fruit")) + .taskFunction(input -> "fruit") + .scorers(regularScorer, tracedScorer) + .build(); + + eval.run(); + + assertTrue(regularScoreCalled.get(), "regular scorer should have been called"); + + var spans = testHarness.awaitExportedSpans(); + + // Should have: 1 root + 1 task + 2 score spans + var rootSpans = + spans.stream() + .filter(s -> s.getParentSpanId().equals(SpanId.getInvalid())) + .toList(); + assertEquals(1, rootSpans.size()); + + var scoreSpans = spans.stream().filter(s -> isScoreSpan(s)).toList(); + assertEquals(2, scoreSpans.size(), "one score span per scorer"); + } + + /** + * Verifies that a {@link TracedScorer} correctly receives the task exception path (no trace) + * via {@link Scorer#scoreForTaskException} when the task throws. + */ + @Test + @SneakyThrows + void tracedScorerReceivesTaskExceptionFallback() { + var taskExceptionFallbackCalled = new AtomicReference<>(false); + + var tracedScorer = + new TracedScorer() { + @Override + public String getName() { + return "traced_task_ex_scorer"; + } + + @Override + public List score( + TaskResult taskResult, BrainstoreTrace trace) { + fail("score(taskResult, trace) should not be called when task throws"); + return List.of(); + } + + @Override + public List scoreForTaskException( + Exception taskException, DatasetCase datasetCase) { + taskExceptionFallbackCalled.set(true); + return List.of(new Score(getName(), 0.0)); + } + }; + + var eval = + testHarness + .braintrust() + .evalBuilder() + .name("unit-test-eval-task-error") + .cases(DatasetCase.of("bad-input", "anything")) + .taskFunction( + input -> { + throw new RuntimeException("task always fails"); + }) + .scorers(tracedScorer) + .build(); + + var result = eval.run(); + assertNotNull(result.getExperimentUrl()); + assertTrue( + taskExceptionFallbackCalled.get(), + "scoreForTaskException should have been called when task throws"); + } + + private static boolean isScoreSpan(SpanData span) { + var attrs = + span.getAttributes() + .get( + io.opentelemetry.api.common.AttributeKey.stringKey( + "braintrust.span_attributes")); + return attrs != null && attrs.contains("\"type\":\"score\""); + } +} diff --git a/braintrust-sdk/src/test/java/dev/braintrust/trace/BrainstoreTraceTest.java b/braintrust-sdk/src/test/java/dev/braintrust/trace/BrainstoreTraceTest.java new file mode 100644 index 00000000..2e86f852 --- /dev/null +++ b/braintrust-sdk/src/test/java/dev/braintrust/trace/BrainstoreTraceTest.java @@ -0,0 +1,443 @@ +package dev.braintrust.trace; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; + +public class BrainstoreTraceTest { + + // ------------------------------------------------------------------------- + // getSpans() — lazy fetch and caching + // ------------------------------------------------------------------------- + + @Test + void getSpansReturnsFetchedSpans() { + var span = Map.of("id", "span-1", "span_attributes", Map.of("type", "llm")); + var trace = traceWithSpans(List.of(span)); + + var result = trace.getSpans(); + + assertEquals(1, result.size()); + assertEquals("span-1", result.get(0).get("id")); + } + + @Test + void getSpansIsCached() { + var callCount = new AtomicInteger(0); + var trace = + new BrainstoreTrace( + () -> { + callCount.incrementAndGet(); + return List.of( + Map.of( + "id", + "span-1", + "span_attributes", + Map.of("type", "llm"))); + }); + + trace.getSpans(); + trace.getSpans(); + trace.getSpans(); + + assertEquals(1, callCount.get(), "supplier should only be called once"); + } + + @Test + void getSpansReturnsEmptyListWhenNoSpans() { + var trace = traceWithSpans(List.of()); + assertTrue(trace.getSpans().isEmpty()); + } + + // ------------------------------------------------------------------------- + // getSpans(String spanType) — type filtering + // ------------------------------------------------------------------------- + + @Test + void getSpansByTypeFiltersCorrectly() { + var llmSpan = + Map.of("id", "llm-1", "span_attributes", Map.of("type", "llm")); + var taskSpan = + Map.of("id", "task-1", "span_attributes", Map.of("type", "task")); + var evalSpan = + Map.of("id", "eval-1", "span_attributes", Map.of("type", "eval")); + var trace = traceWithSpans(List.of(llmSpan, taskSpan, evalSpan)); + + var llmSpans = trace.getSpans("llm"); + + assertEquals(1, llmSpans.size()); + assertEquals("llm-1", llmSpans.get(0).get("id")); + } + + @Test + void getSpansByTypeReturnsEmptyWhenNoMatch() { + var taskSpan = + Map.of("id", "task-1", "span_attributes", Map.of("type", "task")); + var trace = traceWithSpans(List.of(taskSpan)); + + assertTrue(trace.getSpans("llm").isEmpty()); + } + + @Test + void getSpansByTypeHandlesSpanWithNoSpanAttributes() { + var spanWithNoAttrs = Map.of("id", "bare-span"); + var trace = traceWithSpans(List.of(spanWithNoAttrs)); + + assertTrue(trace.getSpans("llm").isEmpty()); + } + + @Test + void getSpansByTypeHandlesSpanWithNoType() { + var spanWithNoType = + Map.of("id", "no-type", "span_attributes", Map.of("name", "foo")); + var trace = traceWithSpans(List.of(spanWithNoType)); + + assertTrue(trace.getSpans("llm").isEmpty()); + } + + // ------------------------------------------------------------------------- + // getLLMConversationThread() + // + // Real brainstore span shape: + // span_id: String + // span_parents: List of parent span_id strings (null/absent for root) + // span_attributes: {type: "llm"|"task"|...} + // metrics: {start: , end: , ...} + // input: List of message maps OR JSON-encoded string of same + // output: List of choice maps [{finish_reason, message: {role, content}}, ...] + // ------------------------------------------------------------------------- + + @Test + void getLLMConversationThreadReturnsEmptyForNoLLMSpans() { + // A trace with only a task span — no LLM spans + var rootTask = taskSpan("root", null, 1.0); + var trace = traceWithSpans(List.of(rootTask)); + + assertTrue(trace.getLLMConversationThread().isEmpty()); + } + + @Test + void getLLMConversationThreadExtractsSingleTurnConversation() { + // root task → llm span with one user message, one assistant reply + var sysMsg = Map.of("role", "system", "content", "be helpful"); + var userMsg = Map.of("role", "user", "content", "strawberry"); + var assistantMsg = Map.of("role", "assistant", "content", "fruit"); + var choice = + Map.of( + "finish_reason", "stop", "index", 0, "message", assistantMsg); + + var root = taskSpan("root", null, 1.0); + var llm = llmSpan("llm1", "root", 1.1, List.of(sysMsg, userMsg), List.of(choice)); + + var trace = traceWithSpans(List.of(root, llm)); + var thread = trace.getLLMConversationThread(); + + // thread: sysMsg, userMsg, then the choice object from output + assertEquals(3, thread.size()); + assertEquals("system", thread.get(0).get("role")); + assertEquals("be helpful", thread.get(0).get("content")); + assertEquals("user", thread.get(1).get("role")); + assertEquals("strawberry", thread.get(1).get("content")); + assertEquals("stop", thread.get(2).get("finish_reason")); + assertEquals(assistantMsg, thread.get(2).get("message")); + } + + @Test + void getLLMConversationThreadDeDuplicatesSequentialMultiTurn() { + // Sequential multi-turn chat matching real trace topology (e.g. OpenCode turns): + // root (task) + // ├── turn1 (task) + // │ └── llm1 (llm) input=[sys, user1] output=[choice1] + // └── turn2 (task) + // └── llm2 (llm) input=[sys, user1, choice1, user2] output=[choice2] + // + // llm2's input is the full conversation history — sys and user1/choice1 appeared in + // llm1 already, so they must NOT be duplicated in the thread. + // Expected thread: sys, user1, choice1, user2, choice2 — 5 items, no duplicates. + var sysMsg = Map.of("role", "system", "content", "be helpful"); + var user1Msg = Map.of("role", "user", "content", "Q1"); + var asst1Msg = Map.of("role", "assistant", "content", "A1"); + var choice1 = Map.of("finish_reason", "stop", "message", asst1Msg); + + var user2Msg = Map.of("role", "user", "content", "Q2"); + var asst2Msg = Map.of("role", "assistant", "content", "A2"); + var choice2 = Map.of("finish_reason", "stop", "message", asst2Msg); + + var root = taskSpan("root", null, 1.0); + var turn1 = taskSpan("turn1", "root", 1.1); + var llm1 = llmSpan("llm1", "turn1", 1.15, List.of(sysMsg, user1Msg), List.of(choice1)); + var turn2 = taskSpan("turn2", "root", 1.2); + var llm2 = + llmSpan( + "llm2", + "turn2", + 1.25, + List.of(sysMsg, user1Msg, choice1, user2Msg), + List.of(choice2)); + + var trace = traceWithSpans(List.of(root, turn2, turn1, llm2, llm1)); // out of order + var thread = trace.getLLMConversationThread(); + + assertEquals(5, thread.size(), "system message and prior context must not be duplicated"); + assertEquals("be helpful", ((Map) thread.get(0)).get("content")); // sys (once) + assertEquals("Q1", ((Map) thread.get(1)).get("content")); // user1 + assertEquals(asst1Msg, ((Map) thread.get(2)).get("message")); // choice1 + assertEquals("Q2", ((Map) thread.get(3)).get("content")); // user2 + assertEquals(asst2Msg, ((Map) thread.get(4)).get("message")); // choice2 + } + + @Test + void getLLMConversationThreadHandlesConcurrentSubagents() { + // Mirrors the real trace: orchestrator LLM dispatches 3 parallel tool calls, + // each spawning an independent subagent with its own single LLM call. + // + // Hierarchy: + // root (task) + // └── turn (task) + // ├── orch-llm (llm) ← orchestrator, fires tool_calls + // ├── agent-a (task) ← subagent A + // │ └── llm-a (llm) + // ├── agent-b (task) ← subagent B + // │ └── llm-b (llm) + // └── agent-c (task) ← subagent C + // └── llm-c (llm) + + var orchInput = + List.of( + Map.of( + "role", "system", "content", "you are an orchestrator"), + Map.of( + "role", "user", "content", "do three things in parallel")); + var orchOutput = + List.of( + Map.of( + "finish_reason", + "stop", + "message", + Map.of( + "role", + "assistant", + "content", + "", + "tool_calls", + List.of()))); + + var inputA = List.of(Map.of("role", "user", "content", "task A")); + var outputA = + List.of( + Map.of( + "finish_reason", + "stop", + "message", + Map.of( + "role", "assistant", "content", "result A"))); + + var inputB = List.of(Map.of("role", "user", "content", "task B")); + var outputB = + List.of( + Map.of( + "finish_reason", + "stop", + "message", + Map.of( + "role", "assistant", "content", "result B"))); + + var inputC = List.of(Map.of("role", "user", "content", "task C")); + var outputC = + List.of( + Map.of( + "finish_reason", + "stop", + "message", + Map.of( + "role", "assistant", "content", "result C"))); + + var root = taskSpan("root", null, 1.0); + var turn = taskSpan("turn", "root", 1.1); + var orchLlm = llmSpan("orch-llm", "turn", 1.2, orchInput, orchOutput); + var agentA = taskSpan("agent-a", "turn", 1.3); + var llmA = llmSpan("llm-a", "agent-a", 1.4, inputA, outputA); + var agentB = taskSpan("agent-b", "turn", 1.3); // same start as agent-a (concurrent) + var llmB = llmSpan("llm-b", "agent-b", 1.4, inputB, outputB); + var agentC = taskSpan("agent-c", "turn", 1.3); + var llmC = llmSpan("llm-c", "agent-c", 1.4, inputC, outputC); + + var trace = + traceWithSpans( + List.of(root, turn, orchLlm, agentA, llmA, agentB, llmB, agentC, llmC)); + var thread = trace.getLLMConversationThread(); + + // The orchestrator's messages appear first, followed by each subagent's messages. + // Each subagent is an independent branch so their messages are NOT de-duplicated. + // + // Expected: orchInput(2) + orchOutput(1) + inputA(1) + outputA(1) + // + inputB(1) + outputB(1) + // + inputC(1) + outputC(1) = 9 items + assertEquals(9, thread.size()); + + // Orchestrator turn + assertEquals("you are an orchestrator", thread.get(0).get("content")); + assertEquals("do three things in parallel", thread.get(1).get("content")); + assertNotNull(thread.get(2).get("message")); // orch output choice + + // Subagent A + assertEquals("task A", thread.get(3).get("content")); + assertEquals("result A", ((Map) thread.get(4).get("message")).get("content")); + + // Subagent B + assertEquals("task B", thread.get(5).get("content")); + assertEquals("result B", ((Map) thread.get(6).get("message")).get("content")); + + // Subagent C + assertEquals("task C", thread.get(7).get("content")); + assertEquals("result C", ((Map) thread.get(8).get("message")).get("content")); + } + + @Test + void getLLMConversationThreadHandlesStringEncodedInput() { + // Real experiment spans encode input as a JSON string, not a raw List + var sysMsg = Map.of("role", "system", "content", "be helpful"); + var userMsg = Map.of("role", "user", "content", "hello"); + var assistantMsg = Map.of("role", "assistant", "content", "hi"); + var choice = Map.of("finish_reason", "stop", "message", assistantMsg); + + String jsonInput = + dev.braintrust.json.BraintrustJsonMapper.toJson(List.of(sysMsg, userMsg)); + + var root = taskSpan("root", null, 1.0); + // Use a mutable map to allow String input + var llmSpanMap = new java.util.LinkedHashMap(); + llmSpanMap.put("span_id", "llm1"); + llmSpanMap.put("span_parents", List.of("root")); + llmSpanMap.put("span_attributes", Map.of("type", "llm")); + llmSpanMap.put("metrics", Map.of("start", 1.1)); + llmSpanMap.put("input", jsonInput); // String, not List + llmSpanMap.put("output", List.of(choice)); + + var trace = traceWithSpans(List.of(root, llmSpanMap)); + var thread = trace.getLLMConversationThread(); + + assertEquals(3, thread.size()); + assertEquals("system", thread.get(0).get("role")); + assertEquals("user", thread.get(1).get("role")); + assertEquals(assistantMsg, thread.get(2).get("message")); + } + + @Test + void getLLMConversationThreadPrunesAutomationSubtrees() { + // Automation spans (e.g. Topics scorer) and all their descendants must be excluded. + // Mirrors the real trace where a Topics automation span has Pipeline/Chat Completion/ + // Embedding children that should not appear in the conversation thread. + var userMsg = Map.of("role", "user", "content", "hello"); + var assistantMsg = Map.of("role", "assistant", "content", "hi"); + var choice = Map.of("finish_reason", "stop", "message", assistantMsg); + + // Scorer LLM input — should NOT appear in thread + var scorerSysMsg = + Map.of("role", "system", "content", "you are an analyst"); + var scorerUserMsg = Map.of("role", "user", "content", "analyze this"); + var scorerChoice = + Map.of( + "finish_reason", + "stop", + "message", + Map.of("role", "assistant", "content", "analysis")); + + var root = taskSpan("root", null, 1.0); + var turn = taskSpan("turn", "root", 1.1); + var llm = llmSpan("llm1", "turn", 1.2, List.of(userMsg), List.of(choice)); + + // automation span at same level as turn — entire subtree should be pruned + var automation = spanWithType("automation-span", "root", "automation", 2.0); + var pipeline = spanWithType("pipeline", "automation-span", "facet", 2.1); + var scorerLlm = + llmSpan( + "scorer-llm", + "pipeline", + 2.2, + List.of(scorerSysMsg, scorerUserMsg), + List.of(scorerChoice)); + + var trace = traceWithSpans(List.of(root, turn, llm, automation, pipeline, scorerLlm)); + var thread = trace.getLLMConversationThread(); + + assertEquals(2, thread.size(), "scorer subtree must be pruned"); + assertEquals(userMsg, thread.get(0)); + assertEquals(choice, thread.get(1)); + } + + // ------------------------------------------------------------------------- + // fetchWithRetry — retry logic (via package-private constructor with custom supplier) + // ------------------------------------------------------------------------- + + @Test + void getSpansRetriesUntilFreshAndNonEmpty() { + var callCount = new AtomicInteger(0); + var results = new ArrayList>>(); + results.add(List.of()); // call 1: empty + results.add(List.of()); // call 2: empty + results.add(List.of(Map.of("id", "span-1", "span_attributes", Map.of("type", "llm")))); + + var trace = + new BrainstoreTrace( + () -> { + int idx = callCount.getAndIncrement(); + return idx < results.size() ? results.get(idx) : List.of(); + }); + + var spans = trace.getSpans(); + trace.getSpans(); // second call — should NOT invoke supplier again + assertEquals(1, callCount.get(), "supplier called exactly once (result is cached)"); + assertTrue(spans.isEmpty()); + } + + // ------------------------------------------------------------------------- + // Helpers — build span maps matching real brainstore shape + // ------------------------------------------------------------------------- + + /** A task/non-LLM span with the given id, optional parent, and start time. */ + private static Map taskSpan(String id, String parentId, double startTime) { + var m = new java.util.LinkedHashMap(); + m.put("span_id", id); + m.put("span_parents", parentId != null ? List.of(parentId) : null); + m.put("span_attributes", Map.of("type", "task")); + m.put("metrics", Map.of("start", startTime)); + return m; + } + + /** An LLM span with the given id, parent, start time, input messages, and output choices. */ + private static Map llmSpan( + String id, + String parentId, + double startTime, + List> input, + List> output) { + var m = new java.util.LinkedHashMap(); + m.put("span_id", id); + m.put("span_parents", parentId != null ? List.of(parentId) : null); + m.put("span_attributes", Map.of("type", "llm")); + m.put("metrics", Map.of("start", startTime)); + m.put("input", input); + m.put("output", output); + return m; + } + + /** A span with an arbitrary type (not "llm" or "task"). */ + private static Map spanWithType( + String id, String parentId, String type, double startTime) { + var m = new java.util.LinkedHashMap(); + m.put("span_id", id); + m.put("span_parents", parentId != null ? List.of(parentId) : null); + m.put("span_attributes", Map.of("type", type)); + m.put("metrics", Map.of("start", startTime)); + return m; + } + + private static BrainstoreTrace traceWithSpans(List> spans) { + return new BrainstoreTrace(() -> spans); + } +} diff --git a/btx/src/test/java/dev/braintrust/sdkspecimpl/SpanFetcher.java b/btx/src/test/java/dev/braintrust/sdkspecimpl/SpanFetcher.java index c08e72c8..cd96a0ae 100644 --- a/btx/src/test/java/dev/braintrust/sdkspecimpl/SpanFetcher.java +++ b/btx/src/test/java/dev/braintrust/sdkspecimpl/SpanFetcher.java @@ -1,15 +1,9 @@ package dev.braintrust.sdkspecimpl; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; import dev.braintrust.TestHarness; import dev.braintrust.VCR; +import dev.braintrust.trace.BrainstoreTrace; import io.opentelemetry.sdk.trace.data.SpanData; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.time.Duration; import java.util.List; import java.util.Map; @@ -17,90 +11,82 @@ * Fetches brainstore spans from the Braintrust API via a BTQL query. * *

Used in {@code RECORD} and {@code OFF} VCR modes where real API calls are made and spans are - * actually ingested into Braintrust. Retries with backoff until the expected number of child spans - * appears (mirrors the Python BTX framework's {@code fetch_braintrust_spans}). + * actually ingested into Braintrust. Delegates to {@link BrainstoreTrace} for retry logic. */ public class SpanFetcher { - private static final ObjectMapper MAPPER = new ObjectMapper(); - private static final TypeReference> MAP_TYPE = new TypeReference<>() {}; - - private static final int BACKOFF_SECONDS = 30; - private static final int MAX_TOTAL_WAIT_SECONDS = 600; - - private final HttpClient httpClient; private final TestHarness harness; - private final String btqlUrl; - private final String apiKey; - private final String projectId; public SpanFetcher(TestHarness harness) { this.harness = harness; - this.httpClient = HttpClient.newBuilder().connectTimeout(Duration.ofSeconds(10)).build(); - this.btqlUrl = harness.braintrustApiBaseUrl().replaceAll("/+$", "") + "/btql"; - this.apiKey = harness.braintrustApiKey(); - this.projectId = TestHarness.defaultProjectId(); } /** - * Fetch child brainstore spans for the given trace, retrying until {@code - * numExpectedChildSpans} are available. + * Fetch brainstore spans for the given trace, retrying until all child spans are available. * * @param rootSpanId hex trace ID (from OTel span context, e.g. {@code * "e6f892e37dac9e3ef2f8906d6600d70c"}) - * @param numExpectedChildSpans number of child spans to wait for + * @param numExpectedChildSpans number of child spans to wait for (used to await in-memory OTel + * export before querying the live API) */ public List> fetch(String rootSpanId, int numExpectedChildSpans) throws Exception { + // Wait for all spans to flush through the in-memory OTel exporter first. + // +1 accounts for the root wrapper span created by SpecExecutor. List otelSpans = harness.awaitExportedSpans(numExpectedChildSpans + 1).stream() .filter(spanData -> spanData.getTraceId().equals(rootSpanId)) .toList(); - List> brainstoreSpans; List> convertedOtelSpans = SpanConverter.toBrainstoreSpans(otelSpans); + if (isReplayMode()) { // Fast path: convert the in-memory OTel spans to brainstore format locally. - brainstoreSpans = convertedOtelSpans; - } else { - // Live path: spans were actually sent to Braintrust — fetch them back via BTQL. - brainstoreSpans = fetchFromBrainstore(rootSpanId, numExpectedChildSpans); - // assert that our converted otel spans will match what is in brainstore - assertConverterMatchesBrainstore(convertedOtelSpans, brainstoreSpans, rootSpanId); + return convertedOtelSpans; } - return brainstoreSpans; - } - /** - * Fetch child brainstore spans for the given trace, retrying until {@code - * numExpectedChildSpans} are available. - * - * @param rootSpanId hex trace ID (from OTel span context, e.g. {@code - * "e6f892e37dac9e3ef2f8906d6600d70c"}) - * @param numExpectedChildSpans number of child spans to wait for - */ - private List> fetchFromBrainstore( - String rootSpanId, int numExpectedChildSpans) throws Exception { - int totalWait = 0; - LookupException lastError = null; + // Live path: spans were actually sent to Braintrust — fetch them back via BTQL. + // Use the child span IDs from in-memory OTel as the completion signal: we block + // until every one of those specific spans has appeared in the backend. + var childSpanIds = + otelSpans.stream() + .filter(s -> s.getParentSpanContext().isValid()) // exclude root wrapper + .map(s -> s.getSpanContext().getSpanId()) + .toList(); - while (true) { - try { - return fetchOnce(rootSpanId, numExpectedChildSpans); - } catch (LookupException e) { - lastError = e; - if (totalWait >= MAX_TOTAL_WAIT_SECONDS) { - break; - } - System.out.printf( - "Spans not ready yet, waiting %ds before retry (total wait: %ds)...%n", - BACKOFF_SECONDS, totalWait); - Thread.sleep(BACKOFF_SECONDS * 1000L); - totalWait += BACKOFF_SECONDS; - } - } - throw new RuntimeException( - "Timed out waiting for brainstore spans after " + MAX_TOTAL_WAIT_SECONDS + "s", - lastError); + var trace = + BrainstoreTrace.forTrace( + harness.braintrust().openApiClient(), + TestHarness.defaultProjectId(), + rootSpanId, + childSpanIds); + + // getSpans() triggers the lazy fetch + retry loop + List> allSpans = trace.getSpans(); + + // Exclude the root wrapper span (span_parents is null) and scorer spans injected by + // the backend — btx only validates the child LLM/tool spans against the spec. + List> brainstoreSpans = + allSpans.stream() + .filter( + s -> { + // Root span has no parents — skip it + Object parents = s.get("span_parents"); + if (parents == null + || (parents instanceof List l && l.isEmpty())) { + return false; + } + // Skip scorer spans injected by the backend + Object sa = s.get("span_attributes"); + if (sa instanceof Map saMap) { + return !"scorer".equals(saMap.get("purpose")); + } + return true; + }) + .toList(); + + // Cross-check that our local OTel→brainstore conversion matches the real thing. + assertConverterMatchesBrainstore(convertedOtelSpans, brainstoreSpans, rootSpanId); + return brainstoreSpans; } /** @@ -238,151 +224,6 @@ private static void assertIsSubset(Object subset, Object superset, String ctx) { } } - @SuppressWarnings("unchecked") - private List> fetchOnce(String rootSpanId, int numExpectedChildSpans) - throws Exception { - - String body = buildBtqlQuery(rootSpanId); - - HttpRequest request = - HttpRequest.newBuilder() - .uri(URI.create(btqlUrl)) - .header("Authorization", "Bearer " + apiKey) - .header("Content-Type", "application/json") - .POST(HttpRequest.BodyPublishers.ofString(body)) - .timeout(Duration.ofSeconds(30)) - .build(); - - HttpResponse response = - httpClient.send(request, HttpResponse.BodyHandlers.ofString()); - - if (response.statusCode() != 200) { - throw new LookupException( - "BTQL query failed with status " - + response.statusCode() - + ": " - + response.body()); - } - - Map result = MAPPER.readValue(response.body(), MAP_TYPE); - List> allSpans = (List>) result.get("data"); - if (allSpans == null) { - throw new LookupException("BTQL response missing 'data' field"); - } - - // Filter out scorer spans injected by the Braintrust backend - List> childSpans = - allSpans.stream() - .filter( - s -> { - Object sa = s.get("span_attributes"); - if (sa instanceof Map) { - return !"scorer".equals(((Map) sa).get("purpose")); - } - return true; - }) - .toList(); - - int actual = childSpans.size(); - if (actual == 0) { - throw new LookupException("No child spans found yet for root_span_id: " + rootSpanId); - } - if (actual < numExpectedChildSpans) { - throw new LookupException( - "Expected " - + numExpectedChildSpans - + " child spans, only found " - + actual - + " so far"); - } - if (actual > numExpectedChildSpans) { - throw new RuntimeException( - "Expected " - + numExpectedChildSpans - + " child spans but found " - + actual - + " — too many (non-retriable)"); - } - - // Retry if any span is still incomplete (output or metrics not yet ingested). - // Braintrust may ingest the span skeleton before the payload fields are indexed. - for (Map span : childSpans) { - if (span.get("output") == null && span.get("metrics") == null) { - throw new LookupException( - "Span found but output/metrics not yet ingested (span_id: " - + span.get("span_id") - + ")"); - } - } - - return childSpans; - } - - /** - * Build the BTQL query JSON string. Uses LinkedHashMap throughout because Map.of() rejects null - * values (needed for the span_parents != null filter). - */ - private String buildBtqlQuery(String rootSpanId) throws Exception { - // span_parents != null literal node (Map.of rejects nulls, so use LinkedHashMap) - Map nullLiteral = new java.util.LinkedHashMap<>(); - nullLiteral.put("op", "literal"); - nullLiteral.put("value", null); - - Map query = new java.util.LinkedHashMap<>(); - query.put("query", buildQueryNode(rootSpanId, nullLiteral)); - query.put("use_columnstore", true); - query.put("use_brainstore", true); - query.put("brainstore_realtime", true); - - return MAPPER.writeValueAsString(query); - } - - private Map buildQueryNode(String rootSpanId, Map nullLiteral) { - Map q = new java.util.LinkedHashMap<>(); - q.put("select", List.of(Map.of("op", "star"))); - q.put( - "from", - Map.of( - "op", "function", - "name", Map.of("op", "ident", "name", List.of("project_logs")), - "args", List.of(Map.of("op", "literal", "value", projectId)))); - q.put( - "filter", - Map.of( - "op", - "and", - "left", - Map.of( - "op", "eq", - "left", Map.of("op", "ident", "name", List.of("root_span_id")), - "right", Map.of("op", "literal", "value", rootSpanId)), - "right", - Map.of( - "op", - "ne", - "left", - Map.of("op", "ident", "name", List.of("span_parents")), - "right", - nullLiteral))); - q.put( - "sort", - List.of( - Map.of( - "expr", - Map.of("op", "ident", "name", List.of("created")), - "dir", - "asc"))); - q.put("limit", 1000); - return q; - } - - /** Retriable error: spans not yet available, caller should retry. */ - private static class LookupException extends Exception { - LookupException(String msg) { - super(msg); - } - } - /** Returns true when running in VCR replay mode (the default). */ private static boolean isReplayMode() { return TestHarness.getVcrMode().equals(VCR.VcrMode.REPLAY); diff --git a/examples/src/main/java/dev/braintrust/examples/TraceScoringExample.java b/examples/src/main/java/dev/braintrust/examples/TraceScoringExample.java new file mode 100644 index 00000000..5679fe4e --- /dev/null +++ b/examples/src/main/java/dev/braintrust/examples/TraceScoringExample.java @@ -0,0 +1,127 @@ +package dev.braintrust.examples; + +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import dev.braintrust.Braintrust; +import dev.braintrust.eval.*; +import dev.braintrust.instrumentation.openai.BraintrustOpenAI; +import dev.braintrust.trace.BrainstoreTrace; +import java.util.List; +import java.util.function.Function; + +/** + * Demonstrates trace scoring: a {@link TracedScorer} that inspects the intermediate LLM spans + * produced during task execution, rather than only examining the final output. + * + *

This is useful for evaluating multi-step tasks where you want to score the reasoning process + * (e.g. checking that the LLM cited sources, stayed on topic, or used the right tool calls) in + * addition to — or instead of — the final answer. + */ +public class TraceScoringExample { + + public static void main(String[] args) throws Exception { + var braintrust = Braintrust.get(); + var openTelemetry = braintrust.openTelemetryCreate(); + var openAIClient = BraintrustOpenAI.wrapOpenAI(openTelemetry, OpenAIOkHttpClient.fromEnv()); + + // Task: ask the LLM to classify a food item + var task = + (Function) + food -> { + var request = + ChatCompletionCreateParams.builder() + .model(ChatModel.GPT_4O_MINI) + .addSystemMessage( + "Classify the given food as either 'fruit' or" + + " 'vegetable'. Return only one word.") + .addUserMessage(food) + .maxTokens(10L) + .temperature(0.0) + .build(); + return openAIClient + .chat() + .completions() + .create(request) + .choices() + .get(0) + .message() + .content() + .orElse("") + .strip() + .toLowerCase(); + }; + + // A TracedScorer that inspects the LLM span to verify a system message was included + var systemMessageChecker = + new TracedScorer() { + @Override + public String getName() { + return "system_message_present"; + } + + @Override + public List score( + TaskResult taskResult, BrainstoreTrace trace) { + // Reconstruct the full conversation thread from LLM spans + var thread = trace.getLLMConversationThread(); + + // Check that a system message was included in the conversation + boolean hasSystemMessage = + thread.stream().anyMatch(msg -> "system".equals(msg.get("role"))); + + System.out.println( + " [trace scorer] conversation thread has " + + thread.size() + + " messages, system message present: " + + hasSystemMessage); + + return List.of(new Score(getName(), hasSystemMessage ? 1.0 : 0.0)); + } + }; + + // A TracedScorer that checks the number of LLM calls made during task execution + var llmCallCounter = + new TracedScorer() { + @Override + public String getName() { + return "single_llm_call"; + } + + @Override + public List score( + TaskResult taskResult, BrainstoreTrace trace) { + var llmSpans = trace.getSpans("llm"); + int callCount = llmSpans.size(); + + System.out.println( + " [trace scorer] LLM calls made: " + + callCount + + " for input: " + + taskResult.datasetCase().input()); + + // Score 1.0 if exactly one LLM call was made, 0.0 otherwise + return List.of(new Score(getName(), callCount == 1 ? 1.0 : 0.0)); + } + }; + + var eval = + braintrust + .evalBuilder() + .name("trace-scoring-example-" + System.currentTimeMillis()) + .cases(DatasetCase.of("strawberry", "fruit")) + .taskFunction(task) + .scorers( + // A regular scorer (final output only) + Scorer.of( + "exact_match", + (expected, result) -> expected.equals(result) ? 1.0 : 0.0), + // Trace-aware scorers (inspect intermediate LLM spans) + systemMessageChecker, + llmCallCounter) + .build(); + + var result = eval.run(); + System.out.println("\n\n" + result.createReportString()); + } +}