diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 187cd853cb27..90866e21a481 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -8,6 +8,7 @@ #### Bugs Fixed * Fixed bug where `CosmosClient` construction with AAD credentials would crash at startup if the semantic reranking inference endpoint environment variable was not set, even when semantic reranking was not being used. The inference service is now lazily initialized on first use. See [PR 46243](https://github.com/Azure/azure-sdk-for-python/pull/46243) +* Fixed bug where the SDK could not connect to the local Cosmos DB emulator running in Docker with a remapped host port. The emulator advertises its internal host/port (e.g. `127.0.0.1:8081`) in its account topology, which is unreachable when the host port differs from `8081`. When the user-supplied endpoint targets `localhost` or `127.0.0.1`, the SDK now reuses that host/port for all regional endpoints returned by the gateway. See [Issue 44380](https://github.com/Azure/azure-sdk-for-python/issues/44380) #### Other Changes * Reduced per-client memory overhead when partition-level circuit breaker (PPCB) is enabled by sharing the partition key range routing map cache across CosmosClient instances connected to the same endpoint, and stripping unused fields from cached partition key ranges using compact PKRange namedtuples. See [PR 46297](https://github.com/Azure/azure-sdk-for-python/pull/46297) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index ef498a27b82a..82ca01fa6f75 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -58,10 +58,50 @@ def __eq__(self, other): def __str__(self): return "Primary: " + self.primary_endpoint -def get_regional_routing_contexts_by_loc(new_locations: list[dict[str, str]]): +def _is_local_emulator_endpoint(endpoint: Optional[str]) -> bool: + """Return True if the endpoint refers to the local Cosmos DB emulator. + + Hosts ``localhost`` and ``127.0.0.1`` are treated as emulator endpoints. + """ + if not endpoint: + return False + try: + hostname = urlparse(endpoint).hostname + except ValueError: + return False + return hostname in ("localhost", "127.0.0.1") + + +def _rewrite_endpoint_with_default(default_endpoint: str, regional_endpoint: str) -> str: + """Rewrite ``regional_endpoint``'s scheme/host/port to match ``default_endpoint``. + + The Cosmos DB emulator advertises its internal host/port (for example + ``127.0.0.1:8081``) in the database account topology. When the emulator + is running in a container with a remapped port, that advertised endpoint + is unreachable from the host. Rewriting it to the user-supplied endpoint + preserves connectivity while keeping the rest of the URI (path, etc.) intact. + """ + try: + default_parsed = urlparse(default_endpoint) + regional_parsed = urlparse(regional_endpoint) + except ValueError: + return regional_endpoint + if not default_parsed.netloc: + return regional_endpoint + return regional_parsed._replace( + scheme=default_parsed.scheme or regional_parsed.scheme, + netloc=default_parsed.netloc, + ).geturl() + + +def get_regional_routing_contexts_by_loc( + new_locations: list[dict[str, str]], + default_endpoint: Optional[str] = None, +): # construct from previous object regional_routing_contexts_by_location: OrderedDict[str, RegionalRoutingContext] = collections.OrderedDict() parsed_locations = [] + rewrite_to_default = _is_local_emulator_endpoint(default_endpoint) for new_location in new_locations: # if name in new_location and same for database account endpoint @@ -71,6 +111,12 @@ def get_regional_routing_contexts_by_loc(new_locations: list[dict[str, str]]): continue try: region_uri = new_location["databaseAccountEndpoint"] + if rewrite_to_default and default_endpoint is not None: + # When targeting the local emulator the server can advertise an + # internal host/port (e.g. 127.0.0.1:8081) that is unreachable + # from the caller (common with Docker port remapping). Reuse + # the user-supplied endpoint host/port so connections succeed. + region_uri = _rewrite_endpoint_with_default(default_endpoint, region_uri) parsed_locations.append(new_location["name"]) regional_object = RegionalRoutingContext(region_uri) regional_routing_contexts_by_location.update({new_location["name"]: regional_object}) @@ -466,15 +512,18 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl self.enable_multiple_writable_locations = enable_multiple_writable_locations if self.connection_policy.EnableEndpointDiscovery: + default_endpoint = self.default_regional_routing_context.get_primary() if read_locations: (self.account_read_regional_routing_contexts_by_location, self.account_locations_by_read_endpoints, - self.account_read_locations) = get_regional_routing_contexts_by_loc(read_locations) + self.account_read_locations) = get_regional_routing_contexts_by_loc( + read_locations, default_endpoint) if write_locations: (self.account_write_regional_routing_contexts_by_location, self.account_locations_by_write_endpoints, - self.account_write_locations) = get_regional_routing_contexts_by_loc(write_locations) + self.account_write_locations) = get_regional_routing_contexts_by_loc( + write_locations, default_endpoint) # if preferred locations is empty and the default endpoint is a global endpoint, # we should use the read locations from gateway as effective preferred locations diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index e68736d78b5a..f008a2f5a4eb 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -670,5 +670,101 @@ def test_location_cache_derived_state_consistency(self): assert read_after_second == [ctx.get_primary() for ctx in expected_read] assert write_after_second == [ctx.get_primary() for ctx in expected_write] + +class TestEmulatorEndpointRewrite: + """Tests that emulator setups (localhost / 127.0.0.1) ignore the host:port + advertised by the gateway and reuse the user-supplied endpoint instead. + + This addresses the issue where the Cosmos emulator running in Docker with + a remapped port (e.g. host port 8888 -> container port 8081) advertises its + internal port back to the client, making the returned regional endpoints + unreachable from the host. + """ + + @staticmethod + def _make_db_account(advertised_endpoint): + db_acc = DatabaseAccount() + db_acc._WritableLocations = [ + {"name": "South Central US", "databaseAccountEndpoint": advertised_endpoint} + ] + db_acc._ReadableLocations = [ + {"name": "South Central US", "databaseAccountEndpoint": advertised_endpoint} + ] + db_acc._EnableMultipleWritableLocations = False + return db_acc + + @pytest.mark.parametrize("user_endpoint", [ + "http://localhost:8888/", + "http://127.0.0.1:9000/", + "https://localhost:8081/", + ]) + def test_emulator_endpoint_is_preserved(self, user_endpoint): + connection_policy = documents.ConnectionPolicy() + lc = LocationCache(default_endpoint=user_endpoint, connection_policy=connection_policy) + db_acc = self._make_db_account("https://127.0.0.1:8081/") + + lc.perform_on_database_account_read(db_acc) + + write_contexts = lc.get_write_regional_routing_contexts() + read_contexts = lc.get_read_regional_routing_contexts() + assert len(write_contexts) == 1 + assert len(read_contexts) == 1 + # The advertised 127.0.0.1:8081 host:port should be replaced with the + # user-supplied host:port so the SDK can reach the emulator. + assert write_contexts[0].get_primary() == user_endpoint + assert read_contexts[0].get_primary() == user_endpoint + + def test_non_emulator_endpoints_are_not_rewritten(self): + user_endpoint = "https://contoso.documents.azure.com:443/" + advertised_endpoint = "https://contoso-southcentralus.documents.azure.com:443/" + connection_policy = documents.ConnectionPolicy() + lc = LocationCache(default_endpoint=user_endpoint, connection_policy=connection_policy) + db_acc = self._make_db_account(advertised_endpoint) + + lc.perform_on_database_account_read(db_acc) + + write_contexts = lc.get_write_regional_routing_contexts() + read_contexts = lc.get_read_regional_routing_contexts() + assert write_contexts[0].get_primary() == advertised_endpoint + assert read_contexts[0].get_primary() == advertised_endpoint + + def test_emulator_endpoint_with_advertised_localhost_is_rewritten(self): + # Even when the advertised endpoint is also a localhost address (just + # with a different port like the in-container 8081), it should still + # be rewritten to the user-supplied host:port. + user_endpoint = "http://localhost:8888/" + advertised_endpoint = "http://localhost:8081/" + connection_policy = documents.ConnectionPolicy() + lc = LocationCache(default_endpoint=user_endpoint, connection_policy=connection_policy) + db_acc = self._make_db_account(advertised_endpoint) + + lc.perform_on_database_account_read(db_acc) + + write_contexts = lc.get_write_regional_routing_contexts() + assert write_contexts[0].get_primary() == user_endpoint + + def test_endpoint_discovery_disabled_skips_rewrite(self): + # When endpoint discovery is disabled, update_location_cache short-circuits + # before populating the per-region routing contexts at all, so the rewrite + # path is never reached and the SDK falls back to the user-supplied + # default endpoint for every request. + user_endpoint = "http://localhost:8888/" + advertised_endpoint = "https://127.0.0.1:8081/" + connection_policy = documents.ConnectionPolicy() + connection_policy.EnableEndpointDiscovery = False + lc = LocationCache(default_endpoint=user_endpoint, connection_policy=connection_policy) + db_acc = self._make_db_account(advertised_endpoint) + + lc.perform_on_database_account_read(db_acc) + + # No per-region contexts are populated when endpoint discovery is off. + assert lc.account_write_regional_routing_contexts_by_location == {} + assert lc.account_read_regional_routing_contexts_by_location == {} + # Routing falls back to the user-supplied default endpoint, not the + # gateway-advertised 127.0.0.1:8081. + assert lc.get_write_regional_routing_context() == user_endpoint + assert lc.get_read_regional_routing_context() == user_endpoint + + if __name__ == "__main__": unittest.main()