diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 187cd853cb27..50821a88d4c3 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -8,6 +8,7 @@ #### Bugs Fixed * Fixed bug where `CosmosClient` construction with AAD credentials would crash at startup if the semantic reranking inference endpoint environment variable was not set, even when semantic reranking was not being used. The inference service is now lazily initialized on first use. See [PR 46243](https://github.com/Azure/azure-sdk-for-python/pull/46243) +* Fixed bug where region names in `preferred_locations` and `excluded_locations` (client-level and per-request) were not matched tolerantly for differences in case, whitespace, hyphens, and underscores. See [PR 46937](https://github.com/Azure/azure-sdk-for-python/pull/46937) #### Other Changes * Reduced per-client memory overhead when partition-level circuit breaker (PPCB) is enabled by sharing the partition key range routing map cache across CosmosClient instances connected to the same endpoint, and stripping unused fields from cached partition key ranges using compact PKRange namedtuples. See [PR 46297](https://github.com/Azure/azure-sdk-for-python/pull/46297) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index ef498a27b82a..0367548c53d5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -37,11 +37,20 @@ logger = logging.getLogger("azure.cosmos.LocationCache") + +def _normalize_region_name(region_name: str | None) -> str: + if region_name is None: + return "" + normalized = "".join(str(region_name).strip().lower().split()) + return normalized.replace("-", "").replace("_", "") + + class EndpointOperationType(object): NoneType = "None" ReadType = "Read" WriteType = "Write" + class RegionalRoutingContext(object): def __init__(self, primary_endpoint: str): self.primary_endpoint: str = primary_endpoint @@ -58,6 +67,7 @@ def __eq__(self, other): def __str__(self): return "Primary: " + self.primary_endpoint + def get_regional_routing_contexts_by_loc(new_locations: list[dict[str, str]]): # construct from previous object regional_routing_contexts_by_location: OrderedDict[str, RegionalRoutingContext] = collections.OrderedDict() @@ -82,13 +92,15 @@ def get_regional_routing_contexts_by_loc(new_locations: list[dict[str, str]]): return regional_routing_contexts_by_location, locations_by_endpoints, parsed_locations + def _get_health_check_endpoints(regional_routing_contexts) -> Set[str]: # should use the endpoints in the order returned from gateway and only the ones specified in preferred locations preferred_endpoints = {context.get_primary() for context in regional_routing_contexts} return preferred_endpoints + def _get_applicable_regional_routing_contexts(regional_routing_contexts: list[RegionalRoutingContext], - location_name_by_endpoint: Mapping[str, str], + normalized_location_name_by_endpoint: Mapping[str, str], fall_back_regional_routing_context: RegionalRoutingContext, exclude_location_list: list[str], circuit_breaker_exclude_list: list[str], @@ -108,8 +120,8 @@ def _get_applicable_regional_routing_contexts(regional_routing_contexts: list[Re :param regional_routing_contexts: The initial list of regional contexts to filter. :type regional_routing_contexts: list[RegionalRoutingContext] - :param location_name_by_endpoint: A mapping from endpoint URL to location name. - :type location_name_by_endpoint: Mapping[str, str] + :param normalized_location_name_by_endpoint: A mapping from endpoint URL to normalized location name. + :type normalized_location_name_by_endpoint: Mapping[str, str] :param fall_back_regional_routing_context: The context to use as a fallback if all others are filtered out. :type fall_back_regional_routing_context: RegionalRoutingContext :param exclude_location_list: A list of location names to exclude, based on user configuration. @@ -121,11 +133,17 @@ def _get_applicable_regional_routing_contexts(regional_routing_contexts: list[Re :return: A filtered and reordered list of regional routing contexts. :rtype: list[RegionalRoutingContext] """ + normalized_excluded_locations = {_normalize_region_name(location) for location in exclude_location_list} + normalized_circuit_breaker_locations = { + _normalize_region_name(location) for location in circuit_breaker_exclude_list + } + # filter endpoints by excluded locations applicable_regional_routing_contexts = [] user_excluded_regional_routing_contexts = [] for regional_routing_context in regional_routing_contexts: - if location_name_by_endpoint.get(regional_routing_context.get_primary()) not in exclude_location_list: + normalized_location_name = normalized_location_name_by_endpoint.get(regional_routing_context.get_primary(), "") + if normalized_location_name not in normalized_excluded_locations: applicable_regional_routing_contexts.append(regional_routing_context) else: user_excluded_regional_routing_contexts.append(regional_routing_context) @@ -134,7 +152,8 @@ def _get_applicable_regional_routing_contexts(regional_routing_contexts: list[Re final_applicable_contexts = [] circuit_breaker_excluded_contexts = [] for regional_routing_context in applicable_regional_routing_contexts: - if location_name_by_endpoint.get(regional_routing_context.get_primary()) in circuit_breaker_exclude_list: + normalized_location_name = normalized_location_name_by_endpoint.get(regional_routing_context.get_primary(), "") + if normalized_location_name in normalized_circuit_breaker_locations: circuit_breaker_excluded_contexts.append(regional_routing_context) else: final_applicable_contexts.append(regional_routing_context) @@ -152,6 +171,7 @@ def _get_applicable_regional_routing_contexts(regional_routing_contexts: list[Re return final_applicable_contexts + class LocationCache(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes def __init__( @@ -172,7 +192,14 @@ def __init__( self.account_locations_by_write_endpoints: dict[str, str] = {} # pylint: disable=name-too-long self.account_write_locations: list[str] = [] self.account_read_locations: list[str] = [] + self._read_locations_by_normalized: dict[str, RegionalRoutingContext] = {} + self._write_locations_by_normalized: dict[str, RegionalRoutingContext] = {} + self._normalized_location_by_read_endpoint: dict[str, str] = {} + self._normalized_location_by_write_endpoint: dict[str, str] = {} + self._normalized_name_by_read_location: dict[str, str] = {} + self._normalized_name_by_write_location: dict[str, str] = {} self.connection_policy: ConnectionPolicy = connection_policy + self._config_mismatch_warning_dedupe: set[tuple[str, tuple[str, ...], tuple[str, ...]]] = set() def get_write_regional_routing_contexts(self): return self.write_regional_routing_contexts @@ -229,6 +256,38 @@ def _get_configured_excluded_locations(self, request: RequestObject) -> list[str return excluded_locations + def _emit_config_mismatch_warning_once( + self, + configured_locations: list[str], + available_locations: list[str], + setting_name: str): + if not configured_locations: + return + + available_by_normalized = {_normalize_region_name(location): location for location in available_locations} + unmatched_locations = [ + location + for location in configured_locations + if _normalize_region_name(location) not in available_by_normalized + ] + + if unmatched_locations: + dedupe_key = ( + setting_name, + tuple(sorted(_normalize_region_name(location) for location in unmatched_locations)), + tuple(sorted(available_by_normalized.keys())), + ) + if dedupe_key in self._config_mismatch_warning_dedupe: + return + self._config_mismatch_warning_dedupe.add(dedupe_key) + + logger.warning( + "Ignoring %s entries that did not match account regions: %s. Available regions: %s", + setting_name, + unmatched_locations, + available_locations, + ) + def _get_applicable_read_regional_routing_contexts(self, request: RequestObject) -> list[RegionalRoutingContext]: # Get configured excluded locations excluded_locations = self._get_configured_excluded_locations(request) @@ -237,7 +296,7 @@ def _get_applicable_read_regional_routing_contexts(self, request: RequestObject) if excluded_locations or request.excluded_locations_circuit_breaker: return _get_applicable_regional_routing_contexts( self.get_read_regional_routing_contexts(), - self.account_locations_by_read_endpoints, + self._normalized_location_by_read_endpoint, self.get_write_regional_routing_contexts()[0], excluded_locations, request.excluded_locations_circuit_breaker or [], @@ -254,7 +313,7 @@ def _get_applicable_write_regional_routing_contexts(self, request: RequestObject if excluded_locations or request.excluded_locations_circuit_breaker: return _get_applicable_regional_routing_contexts( self.get_write_regional_routing_contexts(), - self.account_locations_by_write_endpoints, + self._normalized_location_by_write_endpoint, self.default_regional_routing_context, excluded_locations, request.excluded_locations_circuit_breaker or [], @@ -292,6 +351,8 @@ def _resolve_endpoint_without_preferred_locations(self, request, is_write, locat ordered_locations = self.account_write_locations if is_write else self.account_read_locations all_contexts_by_loc = (self.account_write_regional_routing_contexts_by_location if is_write else self.account_read_regional_routing_contexts_by_location) + normalized_name_by_location = (self._normalized_name_by_write_location if is_write + else self._normalized_name_by_read_location) # Safety check: if endpoint discovery is off or location cache isn't populated, fallback. if not self.connection_policy.EnableEndpointDiscovery or not ordered_locations: @@ -309,14 +370,20 @@ def _resolve_endpoint_without_preferred_locations(self, request, is_write, locat excluded_locations = self._get_configured_excluded_locations(request) circuit_breaker_excluded_locations = request.excluded_locations_circuit_breaker or [] + normalized_excluded_locations = {_normalize_region_name(location) for location in excluded_locations} + normalized_circuit_breaker_locations = { + _normalize_region_name(location) for location in circuit_breaker_excluded_locations + } + applicable_contexts = [] circuit_breaker_contexts = [] for loc_name in ordered_locations: if loc_name in all_contexts_by_loc: context = all_contexts_by_loc[loc_name] - if loc_name in excluded_locations: + normalized_location_name = normalized_name_by_location.get(loc_name, "") + if normalized_location_name in normalized_excluded_locations: continue # Skip user-excluded locations - if loc_name in circuit_breaker_excluded_locations: + if normalized_location_name in normalized_circuit_breaker_locations: circuit_breaker_contexts.append(context) else: applicable_contexts.append(context) @@ -376,6 +443,11 @@ def resolve_service_endpoint(self, request): def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements most_preferred_location = self.effective_preferred_locations[0] if self.effective_preferred_locations else None + normalized_most_preferred_location = ( + _normalize_region_name(most_preferred_location) if most_preferred_location else None + ) + read_locations_by_normalized = self._read_locations_by_normalized + write_locations_by_normalized = self._write_locations_by_normalized # we should schedule refresh in background if we are unable to target the user's most preferredLocation. if self.connection_policy.EnableEndpointDiscovery: @@ -383,18 +455,13 @@ def should_refresh_endpoints(self): # pylint: disable=too-many-return-statement should_refresh = (self.connection_policy.UseMultipleWriteLocations and not self.enable_multiple_writable_locations) - if (most_preferred_location and most_preferred_location in - self.account_read_regional_routing_contexts_by_location): - if (self.account_read_regional_routing_contexts_by_location - and most_preferred_location in self.account_read_regional_routing_contexts_by_location): - most_preferred_read_endpoint = ( - self.account_read_regional_routing_contexts_by_location)[most_preferred_location] - if (most_preferred_read_endpoint and - most_preferred_read_endpoint != self.read_regional_routing_contexts[0]): - # For reads, we can always refresh in background as we can alternate to - # other available read endpoints - return True - else: + if (normalized_most_preferred_location and normalized_most_preferred_location in + read_locations_by_normalized): + most_preferred_read_endpoint = read_locations_by_normalized[normalized_most_preferred_location] + if (most_preferred_read_endpoint and + most_preferred_read_endpoint != self.read_regional_routing_contexts[0]): + # For reads, we can always refresh in background as we can alternate to + # other available read endpoints return True if not self.can_use_multiple_write_locations(): @@ -405,10 +472,11 @@ def should_refresh_endpoints(self): # pylint: disable=too-many-return-statement # we have an alternate write endpoint return True return should_refresh - if (most_preferred_location and - most_preferred_location in self.account_write_regional_routing_contexts_by_location): - most_preferred_write_regional_endpoint = ( - self.account_write_regional_routing_contexts_by_location)[most_preferred_location] + if (normalized_most_preferred_location and + normalized_most_preferred_location in write_locations_by_normalized): + most_preferred_write_regional_endpoint = write_locations_by_normalized[ + normalized_most_preferred_location + ] if most_preferred_write_regional_endpoint: should_refresh |= most_preferred_write_regional_endpoint != self.write_regional_routing_contexts[0] return should_refresh @@ -476,6 +544,32 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl self.account_locations_by_write_endpoints, self.account_write_locations) = get_regional_routing_contexts_by_loc(write_locations) + # Cache normalized lookups once per topology refresh to avoid repeating work per request. + self._read_locations_by_normalized = { + _normalize_region_name(name): context + for name, context in self.account_read_regional_routing_contexts_by_location.items() + } + self._write_locations_by_normalized = { + _normalize_region_name(name): context + for name, context in self.account_write_regional_routing_contexts_by_location.items() + } + self._normalized_location_by_read_endpoint = { + endpoint: _normalize_region_name(name) + for endpoint, name in self.account_locations_by_read_endpoints.items() + } + self._normalized_location_by_write_endpoint = { + endpoint: _normalize_region_name(name) + for endpoint, name in self.account_locations_by_write_endpoints.items() + } + self._normalized_name_by_read_location = { + name: _normalize_region_name(name) + for name in self.account_read_regional_routing_contexts_by_location + } + self._normalized_name_by_write_location = { + name: _normalize_region_name(name) + for name in self.account_write_regional_routing_contexts_by_location + } + # if preferred locations is empty and the default endpoint is a global endpoint, # we should use the read locations from gateway as effective preferred locations if self.connection_policy.PreferredLocations: @@ -489,17 +583,40 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl self.account_write_regional_routing_contexts_by_location, self.account_write_locations, EndpointOperationType.WriteType, - self.default_regional_routing_context + self.default_regional_routing_context, + self._write_locations_by_normalized, ) self.read_regional_routing_contexts = self.get_preferred_regional_routing_contexts( self.account_read_regional_routing_contexts_by_location, self.account_read_locations, EndpointOperationType.ReadType, - self.write_regional_routing_contexts[0] + self.write_regional_routing_contexts[0], + self._read_locations_by_normalized, ) + # Config-time visibility for misconfigured region names. Dedupe ensures periodic + # refreshes do not re-emit identical warnings; new mismatches still surface because + # the dedupe key includes the available account regions snapshot. + if self.connection_policy.PreferredLocations: + self._emit_config_mismatch_warning_once( + self.connection_policy.PreferredLocations, + self.account_read_locations or self.account_write_locations, + "preferred_locations", + ) + if self.connection_policy.ExcludedLocations: + self._emit_config_mismatch_warning_once( + list(self.connection_policy.ExcludedLocations), + self.account_read_locations or self.account_write_locations, + "excluded_locations", + ) + def get_preferred_regional_routing_contexts( - self, endpoints_by_location, ordered_locations, expected_available_operation, fallback_endpoint + self, + endpoints_by_location, + ordered_locations, + expected_available_operation, + fallback_endpoint, + endpoints_by_normalized_location=None, ): regional_endpoints = [] # if enableEndpointDiscovery is false, we always use the defaultEndpoint that @@ -511,12 +628,18 @@ def get_preferred_regional_routing_contexts( ): unavailable_endpoints = [] if self.effective_preferred_locations: + endpoints_by_normalized_location = endpoints_by_normalized_location or { + _normalize_region_name(location): endpoint + for location, endpoint in endpoints_by_location.items() + } + # When client can not use multiple write locations, preferred locations # list should only be used determining read endpoints order. If client # can use multiple write locations, preferred locations list should be # used for determining both read and write endpoints order. for location in self.effective_preferred_locations: - regional_endpoint = endpoints_by_location.get(location) + normalized_location = _normalize_region_name(location) + regional_endpoint = endpoints_by_normalized_location.get(normalized_location) if regional_endpoint: if self.is_endpoint_unavailable(regional_endpoint.get_primary(), expected_available_operation): @@ -593,8 +716,8 @@ def GetLocationalEndpoint(default_endpoint, location_name): global_database_account_name = hostname_parts[0] # Prepare the locational_database_account_name as contoso-eastus for location_name 'east us' - locational_database_account_name = global_database_account_name + "-" + location_name.replace(" ", "") - locational_database_account_name = locational_database_account_name.lower() + normalized_location_name = _normalize_region_name(location_name) + locational_database_account_name = global_database_account_name + "-" + normalized_location_name # Replace 'contoso' with 'contoso-eastus' and return locational_endpoint # as https://contoso-eastus.documents.azure.com:443/ diff --git a/sdk/cosmos/azure-cosmos/cspell.json b/sdk/cosmos/azure-cosmos/cspell.json index 913606d7e7ff..d71f63bc08b7 100644 --- a/sdk/cosmos/azure-cosmos/cspell.json +++ b/sdk/cosmos/azure-cosmos/cspell.json @@ -2,6 +2,7 @@ "ignoreWords": [ "hdrh", "hdrhistogram", + "dedupe", "perfdb", "perfresults", "pkrange", diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index e68736d78b5a..f771ec33a302 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -39,6 +39,26 @@ def create_database_account(enable_multiple_writable_locations): return db_acc +canonical_location1_name = "East US 2" +canonical_location2_name = "West US 3" +canonical_location1_endpoint = "https://eastus2.documents.azure.com" +canonical_location2_endpoint = "https://westus3.documents.azure.com" + + +def create_database_account_with_canonical_regions(enable_multiple_writable_locations): + db_acc = DatabaseAccount() + db_acc._WritableLocations = [ + {"name": canonical_location1_name, "databaseAccountEndpoint": canonical_location1_endpoint}, + {"name": canonical_location2_name, "databaseAccountEndpoint": canonical_location2_endpoint}, + ] + db_acc._ReadableLocations = [ + {"name": canonical_location1_name, "databaseAccountEndpoint": canonical_location1_endpoint}, + {"name": canonical_location2_name, "databaseAccountEndpoint": canonical_location2_endpoint}, + ] + db_acc._EnableMultipleWritableLocations = enable_multiple_writable_locations + return db_acc + + def refresh_location_cache(preferred_locations, use_multiple_write_locations, connection_policy=documents.ConnectionPolicy()): connection_policy.PreferredLocations = preferred_locations connection_policy.UseMultipleWriteLocations = use_multiple_write_locations @@ -670,5 +690,110 @@ def test_location_cache_derived_state_consistency(self): assert read_after_second == [ctx.get_primary() for ctx in expected_read] assert write_after_second == [ctx.get_primary() for ctx in expected_write] + def test_resolve_endpoint_without_preferred_locations_supports_normalized_exclusions(self): + # This specifically exercises _resolve_endpoint_without_preferred_locations by + # setting use_preferred_locations=False. + lc = refresh_location_cache( + preferred_locations=[], + use_multiple_write_locations=True, + ) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_request.use_preferred_locations = False + write_request.excluded_locations = ["east-us-2"] + + assert lc.resolve_service_endpoint(write_request) == canonical_location2_endpoint + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + read_request.use_preferred_locations = False + read_request.excluded_locations = ["west_us_3"] + + assert lc.resolve_service_endpoint(read_request) == canonical_location1_endpoint + + def test_preferred_locations_support_normalized_region_names(self): + # Preferred locations should match account region names even with case/spacing/separator variations. + lc = refresh_location_cache(["east-us-2", " west_us_3 "], True) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + write_contexts = lc.get_write_regional_routing_contexts() + read_contexts = lc.get_read_regional_routing_contexts() + + assert write_contexts[0].get_primary() == canonical_location1_endpoint + assert write_contexts[1].get_primary() == canonical_location2_endpoint + assert read_contexts[0].get_primary() == canonical_location1_endpoint + assert read_contexts[1].get_primary() == canonical_location2_endpoint + + def test_excluded_locations_support_normalized_region_names(self): + # Excluded locations should filter regions even when normalized names are used. + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = ["east-us-2"] + + lc = refresh_location_cache([canonical_location1_name, canonical_location2_name], True, connection_policy) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_request.excluded_locations = ["west_us_3"] + + assert lc.resolve_service_endpoint(read_request) == canonical_location2_endpoint + assert lc.resolve_service_endpoint(write_request) == canonical_location1_endpoint + + def test_should_refresh_endpoints_handles_normalized_preferred_region(self): + # should_refresh_endpoints must match canonical region keys even when the + # customer's preferred location uses non-canonical spelling. + lc = refresh_location_cache(["east-us-2"], True) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + # Most-preferred is already the primary; no background refresh should be triggered. + assert lc.should_refresh_endpoints() is False + + def test_get_locational_endpoint_normalizes_customer_region_string(self): + # GetLocationalEndpoint is used during bootstrap fallback with the customer-supplied + # preferred region string. It must produce the canonical regional URL for any + # accepted normalization variant. + default_endpoint_url = "https://contoso.documents.azure.com:443/" + expected_endpoint = "https://contoso-eastus2.documents.azure.com:443/" + + for region_input in ("East US 2", "east us 2", "eastus2", "east-us-2", "east_us_2", " EastUs2 "): + assert LocationCache.GetLocationalEndpoint(default_endpoint_url, region_input) == expected_endpoint + + def test_unmatched_excluded_locations_warning_is_deduped(self, caplog): + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = ["unknown-region"] + lc = refresh_location_cache([canonical_location1_name], True, connection_policy) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + with caplog.at_level("WARNING", logger="azure.cosmos.LocationCache"): + lc.perform_on_database_account_read(db_acc) + request = RequestObject(ResourceType.Document, _OperationType.Read, None) + lc.resolve_service_endpoint(request) + lc.resolve_service_endpoint(request) + # Simulate a periodic refresh with unchanged topology and config. + lc.perform_on_database_account_read(db_acc) + + unmatched_logs = [ + record for record in caplog.records + if "Ignoring excluded_locations entries" in record.getMessage() + ] + assert len(unmatched_logs) == 1 + + def test_unmatched_preferred_locations_warning_is_deduped(self, caplog): + with caplog.at_level("WARNING", logger="azure.cosmos.LocationCache"): + lc = refresh_location_cache(["unknown-region"], True) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + # Simulate a periodic refresh with unchanged topology and config. + lc.perform_on_database_account_read(db_acc) + + unmatched_logs = [ + record for record in caplog.records + if "Ignoring preferred_locations entries" in record.getMessage() + ] + assert len(unmatched_logs) == 1 + if __name__ == "__main__": unittest.main()