diff --git a/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java index e08ea0fab43..5fd8ec5526e 100644 --- a/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java +++ b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java @@ -91,6 +91,7 @@ class LoadBalancerConfigFactory { static final String SHUFFLE_ADDRESS_LIST_FIELD_NAME = "shuffleAddressList"; static final String ERROR_UTILIZATION_PENALTY = "errorUtilizationPenalty"; + static final String METRIC_NAMES_FOR_COMPUTING_UTILIZATION = "metricNamesForComputingUtilization"; /** * Factory method for creating a new {link LoadBalancerConfigConverter} for a given xDS {@link @@ -134,11 +135,9 @@ class LoadBalancerConfigFactory { * the given config values. */ private static ImmutableMap buildWrrConfig(String blackoutPeriod, - String weightExpirationPeriod, - String oobReportingPeriod, - Boolean enableOobLoadReport, - String weightUpdatePeriod, - Float errorUtilizationPenalty) { + String weightExpirationPeriod, String oobReportingPeriod, Boolean enableOobLoadReport, + String weightUpdatePeriod, Float errorUtilizationPenalty, + ImmutableList metricNamesForComputingUtilization) { ImmutableMap.Builder configBuilder = ImmutableMap.builder(); if (blackoutPeriod != null) { configBuilder.put(BLACK_OUT_PERIOD, blackoutPeriod); @@ -158,6 +157,10 @@ class LoadBalancerConfigFactory { if (errorUtilizationPenalty != null) { configBuilder.put(ERROR_UTILIZATION_PENALTY, errorUtilizationPenalty); } + if (metricNamesForComputingUtilization != null + && !metricNamesForComputingUtilization.isEmpty()) { + configBuilder.put(METRIC_NAMES_FOR_COMPUTING_UTILIZATION, metricNamesForComputingUtilization); + } return ImmutableMap.of(WeightedRoundRobinLoadBalancerProvider.SCHEME, configBuilder.buildOrThrow()); } @@ -284,7 +287,7 @@ static class LoadBalancingPolicyConverter { } private static ImmutableMap convertWeightedRoundRobinConfig( - ClientSideWeightedRoundRobin 
wrr) throws ResourceInvalidException { + ClientSideWeightedRoundRobin wrr) throws ResourceInvalidException { try { return buildWrrConfig( wrr.hasBlackoutPeriod() ? Durations.toString(wrr.getBlackoutPeriod()) : null, @@ -293,7 +296,8 @@ static class LoadBalancingPolicyConverter { wrr.hasOobReportingPeriod() ? Durations.toString(wrr.getOobReportingPeriod()) : null, wrr.hasEnableOobLoadReport() ? wrr.getEnableOobLoadReport().getValue() : null, wrr.hasWeightUpdatePeriod() ? Durations.toString(wrr.getWeightUpdatePeriod()) : null, - wrr.hasErrorUtilizationPenalty() ? wrr.getErrorUtilizationPenalty().getValue() : null); + wrr.hasErrorUtilizationPenalty() ? wrr.getErrorUtilizationPenalty().getValue() : null, + ImmutableList.copyOf(wrr.getMetricNamesForComputingUtilizationList())); } catch (IllegalArgumentException ex) { throw new ResourceInvalidException("Invalid duration in weighted round robin config: " + ex.getMessage()); diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 6cf3189d587..fc270995b3e 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -40,6 +40,7 @@ import io.grpc.services.MetricReport; import io.grpc.util.ForwardingSubchannel; import io.grpc.util.MultiChildLoadBalancer; +import io.grpc.xds.internal.MetricReportUtils; import io.grpc.xds.orca.OrcaOobUtil; import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.orca.OrcaPerRequestUtil; @@ -49,6 +50,7 @@ import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.OptionalDouble; import java.util.Random; import java.util.Set; import java.util.concurrent.ScheduledExecutorService; @@ -87,6 +89,9 @@ * See related documentation: https://cloud.google.com/service-mesh/legacy/load-balancing-apis/proxyless-configure-advanced-traffic-management#custom-lb-config */ 
final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer { + @VisibleForTesting + static boolean enableCustomConfig = + Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_WRR_CUSTOM_METRICS")); private static final LongCounterMetricInstrument RR_FALLBACK_COUNTER; private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER; @@ -189,7 +194,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { this.backendService = ""; } config = - (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); if (weightUpdateTimer != null && weightUpdateTimer.isPending()) { weightUpdateTimer.cancel(); @@ -236,7 +241,8 @@ protected void updateOverallBalancingState() { private SubchannelPicker createReadyPicker(Collection activeList) { WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), - config.enableOobLoadReport, config.errorUtilizationPenalty, sequence); + config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, + config.metricNamesForComputingUtilization); updateWeight(picker); return picker; } @@ -325,12 +331,16 @@ public void addSubchannel(WrrSubchannel wrrSubchannel) { subchannels.add(wrrSubchannel); } - public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty) { + public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty, + ImmutableList metricNamesForComputingUtilization) { if (orcaReportListener != null - && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty) { + && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty + && Objects.equals(orcaReportListener.metricNamesForComputingUtilization, + metricNamesForComputingUtilization)) { return orcaReportListener; } - orcaReportListener = new OrcaReportListener(errorUtilizationPenalty); + 
orcaReportListener = + new OrcaReportListener(errorUtilizationPenalty, metricNamesForComputingUtilization); return orcaReportListener; } @@ -355,18 +365,19 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener { private final float errorUtilizationPenalty; + private final ImmutableList metricNamesForComputingUtilization; - OrcaReportListener(float errorUtilizationPenalty) { + OrcaReportListener(float errorUtilizationPenalty, + ImmutableList metricNamesForComputingUtilization) { this.errorUtilizationPenalty = errorUtilizationPenalty; + this.metricNamesForComputingUtilization = metricNamesForComputingUtilization; } @Override public void onLoadReport(MetricReport report) { + double utilization = getUtilization(report, metricNamesForComputingUtilization); + double newWeight = 0; - // Prefer application utilization and fallback to CPU utilization if unset. - double utilization = - report.getApplicationUtilization() > 0 ? report.getApplicationUtilization() - : report.getCpuUtilization(); if (utilization > 0 && report.getQps() > 0) { double penalty = 0; if (report.getEps() > 0 && errorUtilizationPenalty > 0) { @@ -383,6 +394,37 @@ public void onLoadReport(MetricReport report) { lastUpdated = ticker.nanoTime(); weight = newWeight; } + + /** + * Returns the utilization value computed from the specified metric names. If the application + * utilization is present and valid, it is returned. Otherwise, the maximum of the custom + * metrics specified is returned. If none of the custom metrics are present, the CPU + * utilization is returned. 
+ */ + private double getUtilization(MetricReport report, ImmutableList metricNames) { + double appUtil = report.getApplicationUtilization(); + if (appUtil > 0) { + return appUtil; + } + return getCustomMetricUtilization(report, metricNames) + .orElse(report.getCpuUtilization()); + } + + /** + * Returns the maximum utilization value among the specified metric names. + * Returns OptionalDouble.empty() if NONE of the specified metrics are present in the report, + * or if all present metrics are NaN. + * Returns OptionalDouble.of(maxUtil) if at least one non-NaN metric is present. + */ + private OptionalDouble getCustomMetricUtilization(MetricReport report, + ImmutableList metricNames) { + return metricNames.stream() + .map(name -> MetricReportUtils.getMetric(report, name)) + .filter(OptionalDouble::isPresent) + .mapToDouble(OptionalDouble::getAsDouble) + .filter(d -> !Double.isNaN(d) && d > 0) + .max(); + } } } @@ -403,10 +445,10 @@ private void createAndApplyOrcaListeners() { for (WrrSubchannel weightedSubchannel : wChild.subchannels) { if (config.enableOobLoadReport) { OrcaOobUtil.setListener(weightedSubchannel, - wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty), + wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty, + config.metricNamesForComputingUtilization), OrcaOobUtil.OrcaReportingConfig.newBuilder() - .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS) - .build()); + .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS).build()); } else { OrcaOobUtil.setListener(weightedSubchannel, null, null); } @@ -473,7 +515,8 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker { private volatile StaticStrideScheduler scheduler; WeightedRoundRobinPicker(List children, boolean enableOobLoadReport, - float errorUtilizationPenalty, AtomicInteger sequence) { + float errorUtilizationPenalty, AtomicInteger sequence, + ImmutableList metricNamesForComputingUtilization) { checkNotNull(children, 
"children"); Preconditions.checkArgument(!children.isEmpty(), "empty child list"); this.children = children; @@ -482,7 +525,8 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker { for (ChildLbState child : children) { WeightedChildLbState wChild = (WeightedChildLbState) child; pickers.add(wChild.getCurrentPicker()); - reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty)); + reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty, + metricNamesForComputingUtilization)); } this.pickers = pickers; this.reportListeners = reportListeners; @@ -565,11 +609,11 @@ public boolean equals(Object o) { * The Static Stride Scheduler is an implementation of an earliest deadline first (EDF) scheduler * in which each object's deadline is the multiplicative inverse of the object's weight. *

- * The way in which this is implemented is through a static stride scheduler. + * The way in which this is implemented is through a static stride scheduler. * The Static Stride Scheduler works by iterating through the list of subchannel weights - * and using modular arithmetic to proportionally distribute picks, favoring entries - * with higher weights. It is based on the observation that the intended sequence generated - * from an EDF scheduler is a periodic one that can be achieved through modular arithmetic. + * and using modular arithmetic to proportionally distribute picks, favoring entries + * with higher weights. It is based on the observation that the intended sequence generated + * from an EDF scheduler is a periodic one that can be achieved through modular arithmetic. * The Static Stride Scheduler is more performant than other implementations of the EDF * Scheduler, as it removes the need for a priority queue (and thus mutex locks). *

@@ -720,23 +764,23 @@ static final class WeightedRoundRobinLoadBalancerConfig { final long oobReportingPeriodNanos; final long weightUpdatePeriodNanos; final float errorUtilizationPenalty; + final ImmutableList metricNamesForComputingUtilization; public static Builder newBuilder() { return new Builder(); } private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos, - long weightExpirationPeriodNanos, - boolean enableOobLoadReport, - long oobReportingPeriodNanos, - long weightUpdatePeriodNanos, - float errorUtilizationPenalty) { + long weightExpirationPeriodNanos, boolean enableOobLoadReport, long oobReportingPeriodNanos, + long weightUpdatePeriodNanos, float errorUtilizationPenalty, + ImmutableList metricNamesForComputingUtilization) { this.blackoutPeriodNanos = blackoutPeriodNanos; this.weightExpirationPeriodNanos = weightExpirationPeriodNanos; this.enableOobLoadReport = enableOobLoadReport; this.oobReportingPeriodNanos = oobReportingPeriodNanos; this.weightUpdatePeriodNanos = weightUpdatePeriodNanos; this.errorUtilizationPenalty = errorUtilizationPenalty; + this.metricNamesForComputingUtilization = metricNamesForComputingUtilization; } @Override @@ -751,27 +795,26 @@ public boolean equals(Object o) { && this.oobReportingPeriodNanos == that.oobReportingPeriodNanos && this.weightUpdatePeriodNanos == that.weightUpdatePeriodNanos // Float.compare considers NaNs equal - && Float.compare(this.errorUtilizationPenalty, that.errorUtilizationPenalty) == 0; + && Float.compare(this.errorUtilizationPenalty, that.errorUtilizationPenalty) == 0 + && Objects.equals(this.metricNamesForComputingUtilization, + that.metricNamesForComputingUtilization); } @Override public int hashCode() { - return Objects.hash( - blackoutPeriodNanos, - weightExpirationPeriodNanos, - enableOobLoadReport, - oobReportingPeriodNanos, - weightUpdatePeriodNanos, - errorUtilizationPenalty); + return Objects.hash(blackoutPeriodNanos, weightExpirationPeriodNanos, enableOobLoadReport, + 
oobReportingPeriodNanos, weightUpdatePeriodNanos, errorUtilizationPenalty, + metricNamesForComputingUtilization); } static final class Builder { long blackoutPeriodNanos = 10_000_000_000L; // 10s - long weightExpirationPeriodNanos = 180_000_000_000L; //3min + long weightExpirationPeriodNanos = 180_000_000_000L; // 3min boolean enableOobLoadReport = false; long oobReportingPeriodNanos = 10_000_000_000L; // 10s long weightUpdatePeriodNanos = 1_000_000_000L; // 1s float errorUtilizationPenalty = 1.0F; + ImmutableList metricNamesForComputingUtilization = ImmutableList.of(); private Builder() { @@ -809,10 +852,17 @@ Builder setErrorUtilizationPenalty(float errorUtilizationPenalty) { return this; } + Builder setMetricNamesForComputingUtilization( + List metricNamesForComputingUtilization) { + this.metricNamesForComputingUtilization = + ImmutableList.copyOf(metricNamesForComputingUtilization); + return this; + } + WeightedRoundRobinLoadBalancerConfig build() { return new WeightedRoundRobinLoadBalancerConfig(blackoutPeriodNanos, - weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos, - weightUpdatePeriodNanos, errorUtilizationPenalty); + weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos, + weightUpdatePeriodNanos, errorUtilizationPenalty, metricNamesForComputingUtilization); } } } diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java index 433ea34b857..40b007f7eb8 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java @@ -26,6 +26,7 @@ import io.grpc.Status; import io.grpc.internal.JsonUtil; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; +import java.util.List; import java.util.Map; /** @@ -73,14 +74,16 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) 
{ private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map rawConfig) { Long blackoutPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "blackoutPeriod"); Long weightExpirationPeriodNanos = - JsonUtil.getStringAsDuration(rawConfig, "weightExpirationPeriod"); + JsonUtil.getStringAsDuration(rawConfig, "weightExpirationPeriod"); Long oobReportingPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "oobReportingPeriod"); Boolean enableOobLoadReport = JsonUtil.getBoolean(rawConfig, "enableOobLoadReport"); Long weightUpdatePeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "weightUpdatePeriod"); Float errorUtilizationPenalty = JsonUtil.getNumberAsFloat(rawConfig, "errorUtilizationPenalty"); + List metricNamesForComputingUtilization = JsonUtil.getListOfStrings(rawConfig, + LoadBalancerConfigFactory.METRIC_NAMES_FOR_COMPUTING_UTILIZATION); WeightedRoundRobinLoadBalancerConfig.Builder configBuilder = - WeightedRoundRobinLoadBalancerConfig.newBuilder(); + WeightedRoundRobinLoadBalancerConfig.newBuilder(); if (blackoutPeriodNanos != null) { configBuilder.setBlackoutPeriodNanos(blackoutPeriodNanos); } @@ -102,6 +105,11 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map rawC if (errorUtilizationPenalty != null) { configBuilder.setErrorUtilizationPenalty(errorUtilizationPenalty); } + if (metricNamesForComputingUtilization != null) { + if (WeightedRoundRobinLoadBalancer.enableCustomConfig) { + configBuilder.setMetricNamesForComputingUtilization(metricNamesForComputingUtilization); + } + } return ConfigOrError.fromConfig(configBuilder.build()); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java b/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java new file mode 100644 index 00000000000..eb3d6045c1f --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/MetricReportUtils.java @@ -0,0 +1,80 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * 
you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import io.grpc.services.MetricReport; +import java.util.Map; +import java.util.OptionalDouble; + +/** + * Utilities for parsing and resolving metrics from {@link MetricReport}. + */ +public final class MetricReportUtils { + + private MetricReportUtils() {} + + /** + * Resolves a metric value from the report based on the given metric name. + * The logic checks for specific prefixes to determine where to look up the metric: + *

<ul>
+ * <li>"cpu_utilization" -> getCpuUtilization()</li>
+ * <li>"application_utilization" -> getApplicationUtilization()</li>
+ * <li>"memory_utilization" -> getMemoryUtilization()</li>
+ * <li>"qps" -> getQps()</li>
+ * <li>"eps" -> getEps()</li>
+ * <li>"utilization." -> lookup in utilizationMetrics</li>
+ * <li>"request_cost." -> lookup in requestCostMetrics</li>
+ * <li>"named_metrics." -> lookup in namedMetrics</li>
+ * </ul>
+ * + * @param report The metric report to query. + * @param metricName The name of the custom metric to look up. + * @return The value of the metric if found, or empty if not found. + */ + public static OptionalDouble getMetric(MetricReport report, String metricName) { + if (metricName.equals("cpu_utilization")) { + return OptionalDouble.of(report.getCpuUtilization()); + } else if (metricName.equals("application_utilization")) { + return OptionalDouble.of(report.getApplicationUtilization()); + } else if (metricName.equals("memory_utilization")) { + return OptionalDouble.of(report.getMemoryUtilization()); + } else if (metricName.equals("qps")) { + return OptionalDouble.of(report.getQps()); + } else if (metricName.equals("eps")) { + return OptionalDouble.of(report.getEps()); + } else if (metricName.startsWith("utilization.")) { + Map map = report.getUtilizationMetrics(); + Double val = map.get(metricName.substring("utilization.".length())); + if (val != null) { + return OptionalDouble.of(val); + } + } else if (metricName.startsWith("request_cost.")) { + Map map = report.getRequestCostMetrics(); + Double val = map.get(metricName.substring("request_cost.".length())); + if (val != null) { + return OptionalDouble.of(val); + } + } else if (metricName.startsWith("named_metrics.")) { + Map map = report.getNamedMetrics(); + Double val = map.get(metricName.substring("named_metrics.".length())); + if (val != null) { + return OptionalDouble.of(val); + } + } + return OptionalDouble.empty(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java index e09066461c4..b8b20248026 100644 --- a/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java @@ -101,6 +101,22 @@ public class LoadBalancerConfigFactoryTest { .build())) .build()) .build(); + + private static final Policy WRR_POLICY_WITH_METRICS = 
Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setName("backend") + .setTypedConfig( + Any.pack(ClientSideWeightedRoundRobin.newBuilder() + .setBlackoutPeriod(Duration.newBuilder().setSeconds(287).build()) + .setEnableOobLoadReport( + BoolValue.newBuilder().setValue(true).build()) + .setErrorUtilizationPenalty( + FloatValue.newBuilder().setValue(1.75F).build()) + .addMetricNamesForComputingUtilization("foo") + .addMetricNamesForComputingUtilization("bar") + .build())) + .build()) + .build(); private static final String CUSTOM_POLICY_NAME = "myorg.MyCustomLeastRequestPolicy"; private static final String CUSTOM_POLICY_FIELD_KEY = "choiceCount"; private static final double CUSTOM_POLICY_FIELD_VALUE = 2; @@ -130,6 +146,15 @@ public class LoadBalancerConfigFactoryTest { ImmutableMap.of("weighted_round_robin", ImmutableMap.of("blackoutPeriod","287s", "enableOobLoadReport", true, "errorUtilizationPenalty", 1.75F ))))); + + private static final LbConfig VALID_WRR_CONFIG_WITH_METRICS = + new LbConfig("wrr_locality_experimental", + ImmutableMap.of("childPolicy", + ImmutableList.of(ImmutableMap.of("weighted_round_robin", + ImmutableMap.of("blackoutPeriod", "287s", "enableOobLoadReport", true, + "errorUtilizationPenalty", 1.75F, + LoadBalancerConfigFactory.METRIC_NAMES_FOR_COMPUTING_UTILIZATION, + ImmutableList.of("foo", "bar")))))); private static final LbConfig VALID_RING_HASH_CONFIG = new LbConfig("ring_hash_experimental", ImmutableMap.of("minRingSize", (double) RING_HASH_MIN_RING_SIZE, "maxRingSize", (double) RING_HASH_MAX_RING_SIZE)); @@ -165,6 +190,13 @@ public void weightedRoundRobin() throws ResourceInvalidException { assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_WRR_CONFIG); } + @Test + public void weightedRoundRobin_withMetrics() throws ResourceInvalidException { + Cluster cluster = newCluster(buildWrrPolicy(WRR_POLICY_WITH_METRICS)); + + assertThat(newLbConfig(cluster, true)).isEqualTo(VALID_WRR_CONFIG_WITH_METRICS); + } 
+ @Test public void weightedRoundRobin_invalid() throws ResourceInvalidException { Cluster cluster = newCluster(buildWrrPolicy(Policy.newBuilder() diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java index ddde84ca842..852502fd415 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProviderTest.java @@ -111,6 +111,23 @@ public void parseLoadBalancingConfigDefaultValues() throws IOException { assertThat(config.errorUtilizationPenalty).isEqualTo(1.0F); } + @Test + public void parseLoadBalancingConfigCustomMetrics() throws IOException { + boolean originalEnableCustomConfig = WeightedRoundRobinLoadBalancer.enableCustomConfig; + WeightedRoundRobinLoadBalancer.enableCustomConfig = true; + try { + String lbConfig = "{\"metricNamesForComputingUtilization\" : [\"foo\", \"bar\"]}"; + ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig( + parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + WeightedRoundRobinLoadBalancerConfig config = + (WeightedRoundRobinLoadBalancerConfig) configOrError.getConfig(); + assertThat(config.metricNamesForComputingUtilization).containsExactly("foo", "bar"); + } finally { + WeightedRoundRobinLoadBalancer.enableCustomConfig = originalEnableCustomConfig; + } + } + @SuppressWarnings("unchecked") private static Map parseJsonObject(String json) throws IOException { diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index 9fac46eaf09..46c1e736494 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -19,10 +19,10 @@ import static com.google.common.truth.Truth.assertThat; import 
static io.grpc.ConnectivityState.CONNECTING; import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.Mockito.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeast; -import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; @@ -33,6 +33,7 @@ import com.github.xds.data.orca.v3.OrcaLoadReport; import com.github.xds.service.orca.v3.OrcaLoadReportRequest; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -285,10 +286,12 @@ public void wrrLifeCycle() { WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedTasks = isEnabledHappyEyeballs() ? 
2 : 1; @@ -340,10 +343,12 @@ public void enableOobLoadReportConfig() { (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.9, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedTasks = isEnabledHappyEyeballs() ? 
2 : 1; @@ -399,9 +404,12 @@ private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r1); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r2); - weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r3); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport(r1); + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport(r2); + weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport(r3); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); Map pickCount = new HashMap<>(); @@ -598,10 +606,12 @@ public void blackoutPeriod() { (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new 
OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedCount = isEnabledHappyEyeballs() ? 2 : 1; @@ -661,10 +671,12 @@ public void updateWeightTimer() { assertThat(weightedPicker.getChildren().size()).isEqualTo(2); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedTasks = isEnabledHappyEyeballs() ? 
2 : 1; @@ -678,10 +690,12 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); assertThat(getNumFilteredPendingTasks()).isEqualTo(1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); //timer fires, new weight updated @@ -713,10 +727,12 @@ public void weightExpired() { (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new 
HashMap<>(), new HashMap<>(), new HashMap<>())); int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; @@ -819,10 +835,12 @@ public void unknownWeightIsAvgWeight() { (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); @@ -860,10 +878,12 @@ public void pickFromOtherThread() throws Exception { (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); - weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new 
OrcaReportListener(weightedConfig.errorUtilizationPenalty, + weightedConfig.metricNamesForComputingUtilization).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); CyclicBarrier barrier = new CyclicBarrier(2); @@ -1098,7 +1118,7 @@ public void testImmediateWraparound() { .isLessThan(0.002); } } - + @Test public void testWraparound() { float[] weights = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; @@ -1199,22 +1219,22 @@ public void metrics() { // Send one child LB state an ORCA update with some valid utilization/qps data so that weights // can be calculated, but it's still essentially round_robin Iterator childLbStates = wrr.getChildLbStates().iterator(); - ((WeightedChildLbState)childLbStates.next()).new OrcaReportListener( - weightedConfig.errorUtilizationPenalty).onLoadReport( - InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), - new HashMap<>(), new HashMap<>())); + ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization) + .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), new HashMap<>())); fakeClock.forwardTime(1, TimeUnit.SECONDS); // Now send a second child LB state an ORCA update, so there's real weights - ((WeightedChildLbState)childLbStates.next()).new OrcaReportListener( - weightedConfig.errorUtilizationPenalty).onLoadReport( - InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), - new HashMap<>(), new HashMap<>())); - ((WeightedChildLbState)childLbStates.next()).new OrcaReportListener( - weightedConfig.errorUtilizationPenalty).onLoadReport( - InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), - new HashMap<>(), new HashMap<>())); + ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( + 
weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization) + .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), new HashMap<>())); + ((WeightedChildLbState) childLbStates.next()).new OrcaReportListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization) + .onLoadReport(InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), new HashMap<>())); // Let's reset the mock MetricsRecorder so that it's easier to verify what happened after the // weights were updated @@ -1312,6 +1332,218 @@ public void metricWithRealChannel() throws Exception { eq(Arrays.asList("", ""))); } + + @Test + public void customMetric_priority_appUtilStillPreferred() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost")).build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", 0.5); + // App util = 0.8 + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0.8, 0.1, 1, 0, + new 
HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.8 -> weight=1.25 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(1.25), any(), + any()); + } + + + + @Test + public void customMetric_mapLookup_used() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost")).build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", 0.5); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.5 -> weight=2.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(2.0), any(), + any()); + } + + @Test + public void customMetric_shouldFilterOutAndFallbackToCpu() { + weightedConfig = 
WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost")).build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + // custom metric is NaN, but CPU is 0.1 + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", Double.NaN); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + + // Should fallback to CPU (0.1) + // fallback to cpu: qps=1, util=0.1 -> weight=10.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(10.0), any(), + any()); + } + + @Test + public void customMetric_multipleMetrics_maxUsed() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization( + ImmutableList.of("named_metrics.cost", "named_metrics.score")) + .build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> 
wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", 0.5); + namedMetrics.put("score", 0.8); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.8 (max of 0.5 and 0.8) -> weight=1.25 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(1.25), any(), + any()); + } + + @Test + public void customMetric_allInvalid_fallbackToCpu() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization( + ImmutableList.of("named_metrics.cost", "named_metrics.score", "named_metrics.other")) + .build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + 
.onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", Double.NaN); + namedMetrics.put("score", 0.0); + namedMetrics.put("other", -1.0); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.1 (fallback to cpu) -> weight=10.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(10.0), any(), + any()); + } + + @Test + public void customMetric_mixInvalidAndValid_validUsed() { + weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().setBlackoutPeriodNanos(0) + .setMetricNamesForComputingUtilization(ImmutableList.of("named_metrics.cost", + "named_metrics.score", "named_metrics.other1", "named_metrics.other2")) + .build(); + wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker()); + + syncContext.execute( + () -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(servers) + .setLoadBalancingPolicyConfig(weightedConfig).setAttributes(affinity).build())); + + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel = it.next(); + getSubchannelStateListener(readySubchannel) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + + WeightedChildLbState weightedChild = + (WeightedChildLbState) wrr.getChildLbStates().iterator().next(); + WeightedChildLbState.OrcaReportListener listener = weightedChild.getOrCreateOrcaListener( + 
weightedConfig.errorUtilizationPenalty, weightedConfig.metricNamesForComputingUtilization); + + Map namedMetrics = new HashMap<>(); + namedMetrics.put("cost", Double.NaN); + namedMetrics.put("score", 0.5); + namedMetrics.put("other1", 0.0); + namedMetrics.put("other2", -123.0); + MetricReport report = InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, + new HashMap<>(), new HashMap<>(), namedMetrics); + listener.onLoadReport(report); + // qps=1, util=0.5 -> weight=2.0 + fakeClock.forwardTime(1100, TimeUnit.MILLISECONDS); + verify(mockMetricRecorder).recordDoubleHistogram( + argThat(instr -> instr.getName().equals("grpc.lb.wrr.endpoint_weights")), eq(2.0), any(), + any()); + } + + // Verifies that the MetricRecorder has been called to record a long counter value of 1 for the // given metric name, the given number of times private void verifyLongCounterRecord(String name, int times, long value) { diff --git a/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java b/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java new file mode 100644 index 00000000000..12cf5941ae5 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/MetricReportUtilsTest.java @@ -0,0 +1,130 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds.internal; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import io.grpc.services.InternalCallMetricRecorder; +import io.grpc.services.MetricReport; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.OptionalDouble; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link MetricReportUtils}. */ +@RunWith(JUnit4.class) +public class MetricReportUtilsTest { + + @Test + public void getMetric_cpuUtilization() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + OptionalDouble result = MetricReportUtils.getMetric(report, "cpu_utilization"); + assertTrue(result.isPresent()); + assertEquals(0.5, result.getAsDouble(), 0.0001); + } + + @Test + public void getMetric_applicationUtilization() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + OptionalDouble result = MetricReportUtils.getMetric(report, "application_utilization"); + assertTrue(result.isPresent()); + assertEquals(0.1, result.getAsDouble(), 0.0001); + } + + @Test + public void getMetric_memoryUtilization() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + OptionalDouble result = MetricReportUtils.getMetric(report, "memory_utilization"); + assertTrue(result.isPresent()); + assertEquals(0.2, result.getAsDouble(), 0.0001); + } + + @Test + public void getMetric_qps() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + OptionalDouble result = MetricReportUtils.getMetric(report, "qps"); + assertTrue(result.isPresent()); + assertEquals(10.0, result.getAsDouble(), 0.0001); + } + + @Test + public void getMetric_eps() { + MetricReport report = createMetricReport(0.5, 0.1, 0.2, 10.0, 5.0, Collections.emptyMap()); + 
OptionalDouble result = MetricReportUtils.getMetric(report, "eps"); + assertTrue(result.isPresent()); + assertEquals(5.0, result.getAsDouble(), 0.0001); + } + + @Test + public void getMetric_utilizationMetric() { + Map utilizationMetrics = new HashMap<>(); + utilizationMetrics.put("foo", 1.23); + MetricReport report = InternalCallMetricRecorder.createMetricReport( + 0, 0, 0, 0, 0, Collections.emptyMap(), utilizationMetrics, Collections.emptyMap()); + + OptionalDouble result = MetricReportUtils.getMetric(report, "utilization.foo"); + assertTrue(result.isPresent()); + assertEquals(1.23, result.getAsDouble(), 0.0001); + + assertFalse(MetricReportUtils.getMetric(report, "utilization.bar").isPresent()); + } + + @Test + public void getMetric_requestCostMetric() { + Map requestCostMetrics = new HashMap<>(); + requestCostMetrics.put("foo", 4.56); + MetricReport report = InternalCallMetricRecorder.createMetricReport( + 0, 0, 0, 0, 0, requestCostMetrics, Collections.emptyMap(), Collections.emptyMap()); + + OptionalDouble result = MetricReportUtils.getMetric(report, "request_cost.foo"); + assertTrue(result.isPresent()); + assertEquals(4.56, result.getAsDouble(), 0.0001); + + assertFalse(MetricReportUtils.getMetric(report, "request_cost.bar").isPresent()); + } + + @Test + public void getMetric_namedMetric() { + Map namedMetrics = new HashMap<>(); + namedMetrics.put("foo", 7.89); + MetricReport report = createMetricReport(0, 0, 0, 0, 0, namedMetrics); + + OptionalDouble result = MetricReportUtils.getMetric(report, "named_metrics.foo"); + assertTrue(result.isPresent()); + assertEquals(7.89, result.getAsDouble(), 0.0001); + + assertFalse(MetricReportUtils.getMetric(report, "named_metrics.bar").isPresent()); + } + + @Test + public void getMetric_unknownPrefix() { + MetricReport report = createMetricReport(0, 0, 0, 0, 0, Collections.emptyMap()); + assertFalse(MetricReportUtils.getMetric(report, "unknown.foo").isPresent()); + assertFalse(MetricReportUtils.getMetric(report, 
"foo").isPresent()); + } + + private MetricReport createMetricReport(double cpu, double app, double mem, double qps, + double eps, Map namedMetrics) { + return InternalCallMetricRecorder.createMetricReport( + cpu, app, mem, qps, eps, Collections.emptyMap(), Collections.emptyMap(), namedMetrics); + } +}