Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ def __init__(self, allowed_hosts: list[str]) -> None:
if not isinstance(allowed_hosts, list):
raise TypeError("Allowed hosts must be a list of strings")

for host in allowed_hosts:
if host.startswith("https://") or host.startswith("http://"):
raise ValueError("Allowed host value cannot contain 'https://' or 'http://' prefix")

self.allowed_hosts: set[str] = {x.lower() for x in allowed_hosts}

def get_allowed_hosts(self) -> list[str]:
Expand All @@ -35,6 +39,11 @@ def set_allowed_hosts(self, allowed_hosts: list[str]) -> None:
"""
if not isinstance(allowed_hosts, list):
raise TypeError("Allowed hosts must be a list of strings")

for host in allowed_hosts:
if host.startswith("https://") or host.startswith("http://"):
raise ValueError("Allowed host value cannot contain 'https://' or 'http://' prefix")

self.allowed_hosts = {x.lower() for x in allowed_hosts}

def is_url_host_valid(self, url: str) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from kiota_abstractions.authentication import ApiKeyAuthenticationProvider, AuthenticationProvider, KeyLocation

allowed_hosts = ["https://example.com"]
allowed_hosts = ["example.com"]


def test_initialization():
Expand Down Expand Up @@ -59,3 +59,17 @@ async def test_header_location_authentication(mock_request_information):
await provider.authenticate_request(mock_request_information)
assert "api_key" in mock_request_information.request_headers
assert mock_request_information.headers.get("api_key") == {"test_key_string"}


def test_https_prefix_in_allowed_host():
with pytest.raises(ValueError, match="Allowed host value cannot contain 'https://' or 'http://' prefix"):
ApiKeyAuthenticationProvider(
KeyLocation.Header, "test_key_string", "api_key", ["https://example.com"]
)


def test_http_prefix_in_allowed_host():
with pytest.raises(ValueError, match="Allowed host value cannot contain 'https://' or 'http://' prefix"):
ApiKeyAuthenticationProvider(
KeyLocation.Header, "test_key_string", "api_key", ["http://example.com"]
)