diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java index 309e0da9558..c9e623b4415 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java @@ -163,14 +163,6 @@ static String recordMethodName(String fullMethodName, boolean isGeneratedMethod) return isGeneratedMethod ? fullMethodName : "other"; } - private static Context otelContextWithBaggage() { - Baggage baggage = BAGGAGE_KEY.get(); - if (baggage == null) { - return Context.current(); - } - return Context.current().with(baggage); - } - private static final class ClientTracer extends ClientStreamTracer { @Nullable private static final AtomicLongFieldUpdater outboundWireSizeUpdater; @Nullable private static final AtomicLongFieldUpdater inboundWireSizeUpdater; @@ -286,7 +278,6 @@ public void streamClosed(Status status) { } void recordFinishedAttempt() { - Context otelContext = otelContextWithBaggage(); AttributesBuilder builder = io.opentelemetry.api.common.Attributes.builder() .put(METHOD_KEY, fullMethodName) .put(TARGET_KEY, target) @@ -316,15 +307,15 @@ void recordFinishedAttempt() { if (module.resource.clientAttemptDurationCounter() != null ) { module.resource.clientAttemptDurationCounter() - .record(attemptNanos * SECONDS_PER_NANO, attribute, otelContext); + .record(attemptNanos * SECONDS_PER_NANO, attribute, attemptsState.otelContext); } if (module.resource.clientTotalSentCompressedMessageSizeCounter() != null) { module.resource.clientTotalSentCompressedMessageSizeCounter() - .record(outboundWireSize, attribute, otelContext); + .record(outboundWireSize, attribute, attemptsState.otelContext); } if (module.resource.clientTotalReceivedCompressedMessageSizeCounter() != null) { module.resource.clientTotalReceivedCompressedMessageSizeCounter() - .record(inboundWireSize, attribute, otelContext); + .record(inboundWireSize, attribute, attemptsState.otelContext); } } } @@ -339,6 +330,7 @@ static final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory private boolean callEnded; private final String fullMethodName; private final List callPlugins; + private final Context otelContext; private Status status; private long retryDelayNanos; private long callLatencyNanos; @@ -356,11 +348,12 @@ static final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory String target, CallOptions callOptions, String fullMethodName, - List callPlugins) { + List callPlugins, Context otelContext) { this.module = checkNotNull(module, "module"); this.target = checkNotNull(target, "target"); this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); this.callPlugins = checkNotNull(callPlugins, "callPlugins"); + this.otelContext = checkNotNull(otelContext, "otelContext"); this.attemptDelayStopwatch = module.stopwatchSupplier.get(); this.callStopWatch = module.stopwatchSupplier.get().start(); @@ -375,7 +368,7 @@ static final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory // Record here in case mewClientStreamTracer() would never be called. if (module.resource.clientAttemptCountCounter() != null) { - module.resource.clientAttemptCountCounter().add(1, attribute); + module.resource.clientAttemptCountCounter().add(1, attribute, otelContext); } } @@ -404,7 +397,7 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metada } io.opentelemetry.api.common.Attributes attribute = builder.build(); if (module.resource.clientAttemptCountCounter() != null) { - module.resource.clientAttemptCountCounter().add(1, attribute); + module.resource.clientAttemptCountCounter().add(1, attribute, otelContext); } } if (info.isTransparentRetry()) { @@ -467,7 +460,6 @@ void callEnded(Status status, CallOptions callOptions) { } void recordFinishedCall(CallOptions callOptions) { - Context otelContext = otelContextWithBaggage(); if (attemptsPerCall.get() == 0) { ClientTracer tracer = newClientTracer(null); tracer.attemptNanos = attemptDelayStopwatch.elapsed(TimeUnit.NANOSECONDS); @@ -569,6 +561,7 @@ private static final class ServerTracer extends ServerStreamTracer { private final OpenTelemetryMetricsModule module; private final String fullMethodName; private final List streamPlugins; + private Context otelContext = Context.root(); private volatile boolean isGeneratedMethod; private volatile int streamClosed; private final Stopwatch stopwatch; @@ -583,6 +576,17 @@ private static final class ServerTracer extends ServerStreamTracer { this.stopwatch = module.stopwatchSupplier.get().start(); } + @Override + public io.grpc.Context filterContext(io.grpc.Context context) { + Baggage baggage = BAGGAGE_KEY.get(context); + if (baggage != null) { + otelContext = Context.current().with(baggage); + } else { + otelContext = Context.current(); + } + return context; + } + @Override public void serverCallStarted(ServerCallInfo callInfo) { // Only record method name as an attribute if isSampledToLocalTracing is set to true, @@ -590,12 +594,13 @@ public void serverCallStarted(ServerCallInfo callInfo) { // created methods result in high cardinality metrics. boolean isSampledToLocalTracing = callInfo.getMethodDescriptor().isSampledToLocalTracing(); isGeneratedMethod = isSampledToLocalTracing; + io.opentelemetry.api.common.Attributes attribute = io.opentelemetry.api.common.Attributes.of( METHOD_KEY, recordMethodName(fullMethodName, isSampledToLocalTracing)); if (module.resource.serverCallCountCounter() != null) { - module.resource.serverCallCountCounter().add(1, attribute); + module.resource.serverCallCountCounter().add(1, attribute, otelContext); } } @@ -627,7 +632,6 @@ public void inboundWireSize(long bytes) { */ @Override public void streamClosed(Status status) { - Context otelContext = otelContextWithBaggage(); if (streamClosedUpdater != null) { if (streamClosedUpdater.getAndSet(this, 1) != 0) { return; @@ -678,7 +682,8 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata } streamPlugins = Collections.unmodifiableList(streamPluginsMutable); } - return new ServerTracer(OpenTelemetryMetricsModule.this, fullMethodName, streamPlugins); + return new ServerTracer(OpenTelemetryMetricsModule.this, fullMethodName, + streamPlugins); } } @@ -716,7 +721,7 @@ public ClientCall interceptCall( final CallAttemptsTracerFactory tracerFactory = new CallAttemptsTracerFactory( OpenTelemetryMetricsModule.this, target, callOptions, recordMethodName(method.getFullMethodName(), method.isSampledToLocalTracing()), - callPlugins); + callPlugins, Context.current()); ClientCall call = next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); return new SimpleForwardingClientCall(call) { @@ -739,3 +744,4 @@ public void onClose(Status status, Metadata trailers) { } } } + diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java index 98fdbffc82b..14139b8e439 100644 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java @@ -25,9 +25,11 @@ import static java.util.Collections.emptyList; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import com.google.common.collect.ImmutableMap; @@ -40,13 +42,16 @@ import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.KnownLength; +import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.Server; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerServiceDefinition; import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer.ServerCallInfo; +import io.grpc.ServiceDescriptor; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.inprocess.InProcessChannelBuilder; @@ -55,17 +60,19 @@ import io.grpc.opentelemetry.GrpcOpenTelemetry.TargetFilter; import io.grpc.opentelemetry.OpenTelemetryMetricsModule.CallAttemptsTracerFactory; import io.grpc.opentelemetry.internal.OpenTelemetryConstants; -import io.grpc.stub.MetadataUtils; -import io.grpc.stub.StreamObserver; +import io.grpc.stub.ClientCalls; +import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.GrpcServerRule; -import io.grpc.testing.protobuf.SimpleRequest; -import io.grpc.testing.protobuf.SimpleResponse; -import io.grpc.testing.protobuf.SimpleServiceGrpc; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.baggage.Baggage; +import io.opentelemetry.api.baggage.propagation.W3CBaggagePropagator; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.metrics.DoubleHistogram; import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import io.opentelemetry.context.propagation.ContextPropagators; +import io.opentelemetry.sdk.OpenTelemetrySdk; import io.opentelemetry.sdk.common.InstrumentationScopeInfo; import io.opentelemetry.sdk.metrics.data.MetricData; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; @@ -165,6 +172,8 @@ public String parse(InputStream stream) { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Rule + public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + @Rule public final GrpcServerRule grpcServerRule = new GrpcServerRule().directExecutor(); @Rule public final OpenTelemetryRule openTelemetryTesting = OpenTelemetryRule.create(); @@ -174,14 +183,9 @@ public String parse(InputStream stream) { private ServerCall.Listener mockServerCallListener; @Captor private ArgumentCaptor statusCaptor; - @Mock - private DoubleHistogram mockServerCallDurationHistogram; - @Captor - private ArgumentCaptor contextCaptor; - private io.grpc.Server server; - private io.grpc.ManagedChannel channel; - private OpenTelemetryMetricsResource resource; - private final String serverName = "E2ETestServer-" + Math.random(); + + private Server server; + private ManagedChannel channel; private final FakeClock fakeClock = new FakeClock(); private final MethodDescriptor method = @@ -201,9 +205,7 @@ public String parse(InputStream stream) { public void setUp() throws Exception { testMeter = openTelemetryTesting.getOpenTelemetry() .getMeter(OpenTelemetryConstants.INSTRUMENTATION_SCOPE); - resource = OpenTelemetryMetricsResource.builder() - .serverCallDurationCounter(mockServerCallDurationHistogram) - .build(); + } @After @@ -279,8 +281,8 @@ public void clientBasicMetrics() { enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory( - module, target, CALL_OPTIONS, method.getFullMethodName(), emptyList()); + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); Metadata headers = new Metadata(); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); @@ -447,8 +449,8 @@ public void clientBasicMetrics_withRetryMetricsEnabled_shouldRecordZeroOrBeAbsen enabledMetrics, disableDefaultMetrics); OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory( - module, target, CALL_OPTIONS, method.getFullMethodName(), emptyList()); + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -516,7 +518,7 @@ public void recordAttemptMetrics() { OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, - method.getFullMethodName(), emptyList()); + method.getFullMethodName(), emptyList(), Context.root()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -973,7 +975,7 @@ public void recordAttemptMetrics_withRetryMetricsEnabled() { OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, - method.getFullMethodName(), emptyList()); + method.getFullMethodName(), emptyList(), Context.root()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -1061,7 +1063,7 @@ public void recordAttemptMetrics_withHedgedCalls() { OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, - method.getFullMethodName(), emptyList()); + method.getFullMethodName(), emptyList(), Context.root()); // Create a StreamInfo specifically for hedged attempts final ClientStreamTracer.StreamInfo hedgedStreamInfo = @@ -1142,7 +1144,7 @@ public void clientStreamNeverCreatedStillRecordMetrics() { OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(resource); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = new OpenTelemetryMetricsModule.CallAttemptsTracerFactory(module, target, CALL_OPTIONS, - method.getFullMethodName(), emptyList()); + method.getFullMethodName(), emptyList(), Context.root()); fakeClock.forwardTime(3000, MILLISECONDS); Status status = Status.DEADLINE_EXCEEDED.withDescription("5 seconds"); callAttemptsTracerFactory.callEnded(status, CALL_OPTIONS); @@ -1248,10 +1250,11 @@ public void clientLocalityMetrics_present() { OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( - fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), emptyList()); + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), + emptyList()); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory( - module, target, CALL_OPTIONS, method.getFullMethodName(), emptyList()); + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -1317,10 +1320,11 @@ public void clientLocalityMetrics_missing() { OpenTelemetryMetricsResource resource = GrpcOpenTelemetry.createMetricInstruments(testMeter, enabledMetricsMap, disableDefaultMetrics); OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( - fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), emptyList()); + fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.locality"), + emptyList()); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory( - module, target, CALL_OPTIONS, method.getFullMethodName(), emptyList()); + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -1385,8 +1389,8 @@ public void clientBackendServiceMetrics_present() { fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.backend_service"), emptyList()); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory( - module, target, CALL_OPTIONS, method.getFullMethodName(), emptyList()); + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -1455,8 +1459,8 @@ public void clientBackendServiceMetrics_missing() { fakeClock.getStopwatchSupplier(), resource, Arrays.asList("grpc.lb.backend_service"), emptyList()); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = - new CallAttemptsTracerFactory( - module, target, CALL_OPTIONS, method.getFullMethodName(), emptyList()); + new CallAttemptsTracerFactory(module, target, CALL_OPTIONS, method.getFullMethodName(), + emptyList(), Context.root()); ClientStreamTracer tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); @@ -1532,7 +1536,7 @@ public void customLabel_present() { emptyList()); OpenTelemetryMetricsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = new CallAttemptsTracerFactory( - module, target, callOptions, method.getFullMethodName(), emptyList()); + module, target, callOptions, method.getFullMethodName(), emptyList(), Context.root()); ClientStreamTracer.StreamInfo streamInfo = STREAM_INFO.toBuilder().setCallOptions(callOptions).build(); @@ -1730,44 +1734,6 @@ public void serverBasicMetrics() { } - @Test - public void serverBaggagePropagationToMetrics() { - // 1. Create module and tracer factory using the mock resource - OpenTelemetryMetricsModule module = new OpenTelemetryMetricsModule( - fakeClock.getStopwatchSupplier(), resource, emptyList(), emptyList()); - ServerStreamTracer.Factory tracerFactory = module.getServerTracerFactory(); - ServerStreamTracer tracer = - tracerFactory.newServerStreamTracer(method.getFullMethodName(), new Metadata()); - - // 2. Define the test baggage and gRPC context - Baggage testBaggage = Baggage.builder() - .put("user-id", "67") - .build(); - - // This simulates the context that the Tracing module would have created - io.grpc.Context grpcContext = io.grpc.Context.current() - .withValue(OpenTelemetryConstants.BAGGAGE_KEY, testBaggage); - - // 3. Attach the gRPC context, trigger metric recording, and detach - io.grpc.Context previousContext = grpcContext.attach(); - try { - tracer.streamClosed(Status.OK); - } finally { - grpcContext.detach(previousContext); - } - - // 4. Verify the record call and capture the OTel Context - verify(mockServerCallDurationHistogram).record( - anyDouble(), - any(io.opentelemetry.api.common.Attributes.class), - contextCaptor.capture()); - - // 5. Assert on the captured OTel Context - io.opentelemetry.context.Context capturedOtelContext = contextCaptor.getValue(); - Baggage capturedBaggage = Baggage.fromContext(capturedOtelContext); - - assertEquals("67", capturedBaggage.getEntryValue("user-id")); - } @Test public void targetAttributeFilter_notSet_usesOriginalTarget() { @@ -1909,7 +1875,8 @@ private OpenTelemetryMetricsModule newOpenTelemetryMetricsModule( private OpenTelemetryMetricsModule newOpenTelemetryMetricsModule( OpenTelemetryMetricsResource resource, TargetFilter filter) { return new OpenTelemetryMetricsModule( - fakeClock.getStopwatchSupplier(), resource, emptyList(), emptyList(), filter); + fakeClock.getStopwatchSupplier(), resource, emptyList(), emptyList(), + filter); } static class CallInfo extends ServerCallInfo { @@ -1944,67 +1911,128 @@ public String getAuthority() { } @Test - public void serverBaggagePropagation_EndToEnd() throws Exception { - // 1. Create Both Modules - OpenTelemetry otel = openTelemetryTesting.getOpenTelemetry(); - OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule(otel); - OpenTelemetryMetricsModule metricsModule = new OpenTelemetryMetricsModule( - fakeClock.getStopwatchSupplier(), resource, emptyList(), emptyList()); - - // 2. Create Server with *both* tracer factories - server = InProcessServerBuilder.forName(serverName) - .addService(new SimpleServiceImpl()) // <-- Uses the helper class below - .addStreamTracerFactory(tracingModule.getServerTracerFactory()) - .addStreamTracerFactory(metricsModule.getServerTracerFactory()) - .build() - .start(); + public void serverMetrics_recordsBaggage() { + DoubleHistogram mockDurationHistogram = mock(DoubleHistogram.class); + OpenTelemetryMetricsResource mockResource = OpenTelemetryMetricsResource.builder() + .serverCallDurationCounter(mockDurationHistogram) + .build(); - // 3. Create Client Channel - channel = InProcessChannelBuilder.forName(serverName).directExecutor().build(); + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(mockResource); + ServerStreamTracer.Factory tracerFactory = module.getServerTracerFactory(); - // 4. Manually create baggage headers - Metadata headers = new Metadata(); - headers.put(Metadata.Key.of("baggage", Metadata.ASCII_STRING_MARSHALLER), - "choice=red_pill_or_blue_pill"); + Baggage baggage = Baggage.builder() + .put("baggage-key-1", "baggage-val-1") + .build(); - // 5. Make the gRPC call with these headers - ClientInterceptor headerAttachingInterceptor = - MetadataUtils.newAttachHeadersInterceptor(headers); + io.grpc.Context grpcContext = io.grpc.Context.ROOT + .withValue(OpenTelemetryConstants.BAGGAGE_KEY, baggage); + io.grpc.Context previous = grpcContext.attach(); - // Now, create the stub and apply that interceptor - SimpleServiceGrpc.SimpleServiceBlockingStub stub = - SimpleServiceGrpc.newBlockingStub(channel) - .withInterceptors(headerAttachingInterceptor); + ServerStreamTracer tracer; + try { + tracer = tracerFactory.newServerStreamTracer( + method.getFullMethodName(), new Metadata()); + tracer.filterContext(grpcContext); + tracer.serverCallStarted( + new CallInfo<>(method, Attributes.EMPTY, null)); + } finally { + grpcContext.detach(previous); + } - // Use the imported SimpleRequest - stub.unaryRpc(SimpleRequest.getDefaultInstance()); + try (io.opentelemetry.context.Scope scope = Context.root().makeCurrent()) { + tracer.streamClosed(Status.CANCELLED); + } - // 6. Verify the Mock - verify(mockServerCallDurationHistogram).record( + ArgumentCaptor contextCaptor = ArgumentCaptor.forClass(Context.class); + verify(mockDurationHistogram).record( anyDouble(), - any(io.opentelemetry.api.common.Attributes.class), + any(), contextCaptor.capture()); - // 7. Assert on the captured Context - io.opentelemetry.context.Context capturedOtelContext = contextCaptor.getValue(); - Baggage capturedBaggage = Baggage.fromContext(capturedOtelContext); - - assertEquals("red_pill_or_blue_pill", capturedBaggage.getEntryValue("choice")); + Baggage capturedBaggage = Baggage.fromContext(contextCaptor.getValue()); + assertNotNull("Captured context should have baggage", capturedBaggage); + assertEquals( + "baggage-val-1", capturedBaggage.getEntryValue("baggage-key-1")); } + @Test + public void serverMetrics_recordsBaggage_endToEnd() throws Exception { + DoubleHistogram mockDurationHistogram = mock(DoubleHistogram.class); + OpenTelemetryMetricsResource mockResource = OpenTelemetryMetricsResource.builder() + .serverCallDurationCounter(mockDurationHistogram) + .build(); + + OpenTelemetry openTelemetry = OpenTelemetrySdk + .builder() + .setPropagators(ContextPropagators.create( + W3CBaggagePropagator.getInstance())) + .build(); + + OpenTelemetryMetricsModule module = newOpenTelemetryMetricsModule(mockResource); + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule(openTelemetry); + + String serverName = InProcessServerBuilder.generateName(); + InProcessServerBuilder serverBuilder = InProcessServerBuilder + .forName(serverName).directExecutor(); + + serverBuilder.addStreamTracerFactory(tracingModule.getServerTracerFactory()); + serverBuilder.intercept(tracingModule.getServerSpanPropagationInterceptor()); + serverBuilder.addStreamTracerFactory(module.getServerTracerFactory()); + + serverBuilder.addService(ServerServiceDefinition.builder( + ServiceDescriptor.newBuilder("package1.service2") + .addMethod(method) + .build()) + .addMethod(method, new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + call.sendHeaders(new Metadata()); + call.sendMessage("response"); + call.close(Status.OK, new Metadata()); + return new ServerCall.Listener() { + }; + } + }).build()); + grpcCleanup.register(serverBuilder.build().start()); + + InProcessChannelBuilder channelBuilder = InProcessChannelBuilder + .forName(serverName).directExecutor(); + channelBuilder.intercept(tracingModule.getClientInterceptor()); + channelBuilder.intercept(module.getClientInterceptor(serverName)); + Channel channel = grpcCleanup.register(channelBuilder.intercept(new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return next.newCall(method, callOptions); + } + }).build()); + + Baggage baggage = Baggage.builder() + .put("baggage-key-1", "baggage-val-1") + .build(); + + Context otelContext = Context.root().with(baggage); + + try (Scope scope = otelContext.makeCurrent()) { + ClientCalls.blockingUnaryCall(channel, + method, CallOptions.DEFAULT, "request"); + } + + ArgumentCaptor contextCaptor = ArgumentCaptor.forClass(Context.class); + verify(mockDurationHistogram).record( + anyDouble(), + any(), + contextCaptor.capture()); + + Baggage capturedBaggage = Baggage.fromContext(contextCaptor.getValue()); + assertNotNull("Captured context should have baggage", capturedBaggage); + assertEquals( + "baggage-val-1", capturedBaggage.getEntryValue("baggage-key-1")); + } + private static List sortByName(List metrics) { metrics.sort((m1, m2) -> m1.getName().compareTo(m2.getName())); return metrics; } - - /** - * A simple service implementation for the E2E test. - */ - private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { - @Override - public void unaryRpc(SimpleRequest request, StreamObserver responseObserver) { - responseObserver.onNext(SimpleResponse.getDefaultInstance()); - responseObserver.onCompleted(); - } - } }