Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/bedrock_agentcore/runtime/a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
from ..config_bundle.baggage import _extract_baggage
from .context import BedrockAgentCoreContext
from .models import (
_AUTHORIZATION_HEADER_LOWER,
ACCESS_TOKEN_HEADER,
AGENTCORE_RUNTIME_URL_ENV,
AUTHORIZATION_HEADER,
BAGGAGE_KEY_EXPERIMENT_ARN,
BAGGAGE_KEY_EXPERIMENT_VARIANT,
CUSTOM_HEADER_PREFIX,
OAUTH2_CALLBACK_URL_HEADER,
REQUEST_ID_HEADER,
SESSION_HEADER,
PingStatus,
is_forwardable_header,
)
from .tracing import _ensure_baggage_processor_registered

Expand Down Expand Up @@ -131,12 +132,15 @@ def build(self, request: Any) -> Any:
if oauth2_callback_url:
BedrockAgentCoreContext.set_oauth2_callback_url(oauth2_callback_url)

# Collect forwardable request headers.
# Authorization is normalised to a canonical key regardless of wire casing
# (HTTP/2 always lowercases headers; HTTP/1.1 may preserve mixed case).
# All other headers are checked against the runtime header allowlist rules.
request_headers: dict[str, str] = {}
authorization_header = headers.get(AUTHORIZATION_HEADER)
if authorization_header is not None:
request_headers[AUTHORIZATION_HEADER] = authorization_header
for header_name, header_value in headers.items():
if header_name.lower().startswith(CUSTOM_HEADER_PREFIX.lower()):
if header_name.lower() == _AUTHORIZATION_HEADER_LOWER:
request_headers[AUTHORIZATION_HEADER] = header_value
elif is_forwardable_header(header_name):
request_headers[header_name] = header_value
if request_headers:
BedrockAgentCoreContext.set_request_headers(request_headers)
Expand Down
14 changes: 9 additions & 5 deletions src/bedrock_agentcore/runtime/ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@
from ..config_bundle.baggage import _extract_baggage
from .context import BedrockAgentCoreContext, RequestContext
from .models import (
_AUTHORIZATION_HEADER_LOWER,
ACCESS_TOKEN_HEADER,
AUTHORIZATION_HEADER,
BAGGAGE_KEY_EXPERIMENT_ARN,
BAGGAGE_KEY_EXPERIMENT_VARIANT,
CUSTOM_HEADER_PREFIX,
OAUTH2_CALLBACK_URL_HEADER,
REQUEST_ID_HEADER,
SESSION_HEADER,
PingStatus,
is_forwardable_header,
)
from .tracing import _ensure_baggage_processor_registered

Expand Down Expand Up @@ -169,12 +170,15 @@ def _build_request_context(self, request: Request | WebSocket) -> RequestContext
if oauth2_callback_url:
BedrockAgentCoreContext.set_oauth2_callback_url(oauth2_callback_url)

# Collect forwardable request headers.
# Authorization is normalised to a canonical key regardless of wire casing
# (HTTP/2 always lowercases headers; HTTP/1.1 may preserve mixed case).
# All other headers are checked against the runtime header allowlist rules.
request_headers: dict[str, str] = {}
authorization_header = headers.get(AUTHORIZATION_HEADER)
if authorization_header is not None:
request_headers[AUTHORIZATION_HEADER] = authorization_header
for header_name, header_value in headers.items():
if header_name.lower().startswith(CUSTOM_HEADER_PREFIX.lower()):
if header_name.lower() == _AUTHORIZATION_HEADER_LOWER:
request_headers[AUTHORIZATION_HEADER] = header_value
elif is_forwardable_header(header_name):
request_headers[header_name] = header_value
if request_headers:
BedrockAgentCoreContext.set_request_headers(request_headers)
Expand Down
18 changes: 9 additions & 9 deletions src/bedrock_agentcore/runtime/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
from ..config_bundle.client import ConfigBundleClient
from .context import BedrockAgentCoreContext, RequestContext
from .models import (
_AUTHORIZATION_HEADER_LOWER,
ACCESS_TOKEN_HEADER,
AUTHORIZATION_HEADER,
BAGGAGE_KEY_EXPERIMENT_ARN,
BAGGAGE_KEY_EXPERIMENT_VARIANT,
CUSTOM_HEADER_PREFIX,
OAUTH2_CALLBACK_URL_HEADER,
REQUEST_ID_HEADER,
SESSION_HEADER,
Expand All @@ -45,6 +45,7 @@
TASK_ACTION_JOB_STATUS,
TASK_ACTION_PING_STATUS,
PingStatus,
is_forwardable_header,
)
from .tracing import _ensure_baggage_processor_registered
from .utils import convert_complex_objects
Expand Down Expand Up @@ -415,17 +416,16 @@ def _build_request_context(self, request) -> RequestContext:
if oauth2_callback_url:
BedrockAgentCoreContext.set_oauth2_callback_url(oauth2_callback_url)

# Collect relevant request headers (Authorization + Custom headers)
# Collect forwardable request headers.
# Authorization is normalised to a canonical key regardless of wire casing
# (HTTP/2 always lowercases headers; HTTP/1.1 may preserve mixed case).
# All other headers are checked against the runtime header allowlist rules.
request_headers = {}

# Add Authorization header if present
authorization_header = headers.get(AUTHORIZATION_HEADER)
if authorization_header is not None:
request_headers[AUTHORIZATION_HEADER] = authorization_header

# Add custom headers with the specified prefix
for header_name, header_value in headers.items():
if header_name.lower().startswith(CUSTOM_HEADER_PREFIX.lower()):
if header_name.lower() == _AUTHORIZATION_HEADER_LOWER:
request_headers[AUTHORIZATION_HEADER] = header_value
elif is_forwardable_header(header_name):
request_headers[header_name] = header_value

# Set in context if any headers were found
Expand Down
157 changes: 157 additions & 0 deletions src/bedrock_agentcore/runtime/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,140 @@ class PingStatus(str, Enum):
CUSTOM_HEADER_PREFIX = "X-Amzn-Bedrock-AgentCore-Runtime-Custom-"
AGENTCORE_RUNTIME_URL_ENV = "AGENTCORE_RUNTIME_URL"

# Headers that cannot be forwarded to agent code.
# Source: https://docs.aws.amazon.com/bedrock-agentcore/latest/devguide/runtime-header-allowlist.html
RESTRICTED_HEADERS: frozenset[str] = frozenset(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if this gets out of sync?

h.lower()
for h in [
# Authentication & Authorization
"Proxy-Authorization",
"WWW-Authenticate",
# Content Negotiation
"Accept",
"Accept-Charset",
"Accept-Encoding",
"Accept-Language",
"Content-Type",
"Content-Length",
"Content-Encoding",
"Content-Language",
"Content-Location",
"Content-Range",
# Caching
"Cache-Control",
"ETag",
"Expires",
"If-Match",
"If-Modified-Since",
"If-None-Match",
"If-Range",
"If-Unmodified-Since",
"Last-Modified",
"Pragma",
"Vary",
# Connection Management
"Connection",
"Keep-Alive",
"Proxy-Connection",
"Upgrade",
# Request Context
"Host",
"User-Agent",
"Referer",
"From",
# Range / Transfer
"Range",
"Accept-Ranges",
"Transfer-Encoding",
"TE",
"Trailer",
# Server Information
"Server",
"Date",
"Location",
"Retry-After",
# Cookies
"Set-Cookie",
"Cookie",
# Security
"Content-Security-Policy",
"Content-Security-Policy-Report-Only",
"Strict-Transport-Security",
"X-Content-Type-Options",
"X-Frame-Options",
"X-XSS-Protection",
"Referrer-Policy",
"Permissions-Policy",
"Cross-Origin-Embedder-Policy",
"Cross-Origin-Opener-Policy",
"Cross-Origin-Resource-Policy",
# CORS
"Access-Control-Allow-Origin",
"Access-Control-Allow-Methods",
"Access-Control-Allow-Headers",
"Access-Control-Allow-Credentials",
"Access-Control-Expose-Headers",
"Access-Control-Max-Age",
"Access-Control-Request-Method",
"Access-Control-Request-Headers",
"Origin",
# Client Hints
"Accept-CH",
"Accept-CH-Lifetime",
"DPR",
"Width",
"Viewport-Width",
"Downlink",
"ECT",
"RTT",
"Save-Data",
# Experimental / Proposed
"Clear-Site-Data",
"Feature-Policy",
"Expect-CT",
"Public-Key-Pins",
"Public-Key-Pins-Report-Only",
# Proxy
"Via",
"Forwarded",
"X-Forwarded-For",
"X-Forwarded-Host",
"X-Forwarded-Proto",
"X-Real-IP",
"X-Requested-With",
"X-CSRF-Token",
# IP Spoofing / URL Manipulation
"True-Client-IP",
"X-Client-IP",
"X-Cluster-Client-IP",
"X-Originating-IP",
"X-Source-IP",
"X-Original-URL",
"X-Original-Host",
"X-Rewrite-URL",
# CDN / Proxy
"CF-Ray",
"CF-Connecting-IP",
"X-Amz-Cf-Id",
"X-Cache",
"X-Served-By",
# HTTP/2 Pseudo Headers
":method",
":path",
":scheme",
":authority",
":status",
# Server Push
"Link",
# WebSocket
"Sec-WebSocket-Key",
"Sec-WebSocket-Accept",
"Sec-WebSocket-Version",
"Sec-WebSocket-Protocol",
"Sec-WebSocket-Extensions",
]
)

# Baggage keys for routing experiment span attributes
BAGGAGE_KEY_EXPERIMENT_ARN = "aws.agentcore.gateway.routing_experiment_arn"
BAGGAGE_KEY_EXPERIMENT_VARIANT = "aws.agentcore.gateway.routing_experiment_variant_name"
Expand All @@ -32,3 +166,26 @@ class PingStatus(str, Enum):
TASK_ACTION_FORCE_HEALTHY = "force_healthy"
TASK_ACTION_FORCE_BUSY = "force_busy"
TASK_ACTION_CLEAR_FORCED_STATUS = "clear_forced_status"


_CUSTOM_HEADER_PREFIX_LOWER = CUSTOM_HEADER_PREFIX.lower()
_AUTHORIZATION_HEADER_LOWER = AUTHORIZATION_HEADER.lower()


def is_forwardable_header(header_name: str) -> bool:
"""Return True if the header may be forwarded to agent code.

Rules (from the AgentCore runtime header allowlist docs):
- Not in the restricted headers list
- Does not start with ``x-amz-`` (reserved for AWS SigV4 signing)
- Does not start with ``x-amzn-`` unless it starts with the legacy
``X-Amzn-Bedrock-AgentCore-Runtime-Custom-`` prefix
"""
lower = header_name.lower()
if lower in RESTRICTED_HEADERS:
return False
if lower.startswith("x-amz-"):
return False
if lower.startswith("x-amzn-") and not lower.startswith(_CUSTOM_HEADER_PREFIX_LOWER):
return False
return True
33 changes: 33 additions & 0 deletions tests/bedrock_agentcore/runtime/test_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,39 @@ def _test():

self._run_in_isolated_context(_test)

def test_forwards_non_restricted_custom_headers(self):
"""Non-restricted headers (e.g. X-Api-Key) are forwarded; restricted ones are not."""

def _test():
from starlette.requests import Request

scope = {
"type": "http",
"method": "POST",
"path": "/",
"headers": [
(b"x-api-key", b"my-key"),
(b"x-custom-signature", b"sha256=abc"),
(b"content-type", b"application/json"), # restricted
(b"x-amz-date", b"20250101T000000Z"), # restricted (x-amz-)
(b"x-amzn-trace-id", b"trace-123"), # restricted (x-amzn-)
],
"query_string": b"",
}
request = Request(scope)
builder = BedrockCallContextBuilder()
builder.build(request)

headers = BedrockAgentCoreContext.get_request_headers()
assert headers is not None
assert any(k.lower() == "x-api-key" for k in headers)
assert any(k.lower() == "x-custom-signature" for k in headers)
assert not any(k.lower() == "content-type" for k in headers)
assert not any(k.lower() == "x-amz-date" for k in headers)
assert not any(k.lower() == "x-amzn-trace-id" for k in headers)

self._run_in_isolated_context(_test)

def test_auto_generates_request_id_when_missing(self):
def _test():
from starlette.requests import Request
Expand Down
33 changes: 33 additions & 0 deletions tests/bedrock_agentcore/runtime/test_ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,39 @@ def _test():

self._run_in_isolated_context(_test)

def test_forwards_non_restricted_custom_headers(self):
"""Non-restricted headers (e.g. X-Api-Key) are forwarded; restricted ones are not."""

def _test():
from starlette.requests import Request

scope = {
"type": "http",
"method": "POST",
"path": "/invocations",
"headers": [
(b"x-api-key", b"my-key"),
(b"x-custom-signature", b"sha256=abc"),
(b"content-type", b"application/json"), # restricted
(b"x-amz-date", b"20250101T000000Z"), # restricted (x-amz-)
(b"x-amzn-trace-id", b"trace-123"), # restricted (x-amzn-)
],
"query_string": b"",
}
request = Request(scope)
app = AGUIApp()
app._build_request_context(request)

headers = BedrockAgentCoreContext.get_request_headers()
assert headers is not None
assert any(k.lower() == "x-api-key" for k in headers)
assert any(k.lower() == "x-custom-signature" for k in headers)
assert not any(k.lower() == "content-type" for k in headers)
assert not any(k.lower() == "x-amz-date" for k in headers)
assert not any(k.lower() == "x-amzn-trace-id" for k in headers)

self._run_in_isolated_context(_test)

def test_auto_generates_request_id_when_missing(self):
def _test():
from starlette.requests import Request
Expand Down
Loading
Loading