Skip to content

Commit 9a14569

Browse files
committed
feat(logger): add DurableContext unwrap support in lambda_context helper
- Add _unwrap_durable_context() helper to detect and unwrap DurableContext - Update build_lambda_context_model() to handle both LambdaContext and DurableContext - Add comprehensive tests for DurableContext support in logger decorators - Fixes compatibility with AWS Durable Execution SDK
1 parent 11b68ff commit 9a14569

File tree

2 files changed

+191
-6
lines changed

2 files changed

+191
-6
lines changed

aws_lambda_powertools/logging/lambda_context.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from typing import Any
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
4+
5+
if TYPE_CHECKING:
6+
from aws_lambda_powertools.utilities.typing import LambdaContext
27

38

49
class LambdaContextModel:
@@ -34,25 +39,47 @@ def __init__(
3439
self.function_request_id = function_request_id
3540

3641

42+
def _unwrap_durable_context(context: Any) -> LambdaContext:
43+
"""Unwrap Lambda Context from DurableContext if applicable.
44+
45+
Parameters
46+
----------
47+
context : object
48+
Lambda context object or DurableContext
49+
50+
Returns
51+
-------
52+
LambdaContext
53+
The unwrapped Lambda context
54+
"""
55+
# Check if this is a DurableContext by duck typing
56+
if hasattr(context, "lambda_context") and hasattr(context, "state"):
57+
return context.lambda_context
58+
59+
return context
60+
61+
3762
def build_lambda_context_model(context: Any) -> LambdaContextModel:
3863
"""Captures Lambda function runtime info to be used across all log statements
3964
4065
Parameters
4166
----------
4267
context : object
43-
Lambda context object
68+
Lambda context object or DurableContext
4469
4570
Returns
4671
-------
4772
LambdaContextModel
4873
Lambda context only with select fields
4974
"""
75+
# Unwrap DurableContext if applicable
76+
lambda_context = _unwrap_durable_context(context)
5077

5178
context = {
52-
"function_name": context.function_name,
53-
"function_memory_size": context.memory_limit_in_mb,
54-
"function_arn": context.invoked_function_arn,
55-
"function_request_id": context.aws_request_id,
79+
"function_name": lambda_context.function_name,
80+
"function_memory_size": lambda_context.memory_limit_in_mb,
81+
"function_arn": lambda_context.invoked_function_arn,
82+
"function_request_id": lambda_context.aws_request_id,
5683
}
5784

5885
return LambdaContextModel(**context)
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
"""Tests for Logger with DurableContext support."""
2+
3+
import io
4+
import json
5+
import random
6+
import string
7+
from collections import namedtuple
8+
from unittest.mock import Mock
9+
10+
import pytest
11+
12+
from aws_lambda_powertools import Logger
13+
from aws_lambda_powertools.utilities.typing import DurableContextProtocol
14+
15+
16+
@pytest.fixture
17+
def stdout():
18+
return io.StringIO()
19+
20+
21+
@pytest.fixture
22+
def lambda_context():
23+
lambda_context = {
24+
"function_name": "test",
25+
"memory_limit_in_mb": 128,
26+
"invoked_function_arn": "arn:aws:lambda:eu-west-1:809313241:function:test",
27+
"aws_request_id": "52fdfc07-2182-154f-163f-5f0f9a621d72",
28+
}
29+
return namedtuple("LambdaContext", lambda_context.keys())(*lambda_context.values())
30+
31+
32+
@pytest.fixture
33+
def service_name():
34+
chars = string.ascii_letters + string.digits
35+
return "".join(random.SystemRandom().choice(chars) for _ in range(15))
36+
37+
38+
def capture_logging_output(stdout):
39+
return json.loads(stdout.getvalue().strip())
40+
41+
42+
def capture_multiple_logging_statements_output(stdout):
43+
return [json.loads(line.strip()) for line in stdout.getvalue().split("\n") if line]
44+
45+
46+
@pytest.fixture
47+
def durable_context(lambda_context):
48+
"""Create a mock DurableContext with embedded Lambda context."""
49+
durable_ctx = Mock(spec=DurableContextProtocol)
50+
durable_ctx.lambda_context = lambda_context
51+
durable_ctx.state = Mock(operations=[{"id": "op1"}])
52+
return durable_ctx
53+
54+
55+
def test_inject_lambda_context_with_durable_context(durable_context, stdout, service_name):
56+
"""Test that inject_lambda_context works with DurableContext."""
57+
# GIVEN Logger is initialized
58+
logger = Logger(service=service_name, stream=stdout)
59+
60+
# WHEN a lambda function is decorated with logger and receives DurableContext
61+
@logger.inject_lambda_context
62+
def handler(event, context):
63+
logger.info("Hello from durable function")
64+
65+
handler({}, durable_context)
66+
67+
# THEN lambda contextual info from the unwrapped context should be in the logs
68+
log = capture_logging_output(stdout)
69+
70+
expected_logger_context_keys = (
71+
"function_name",
72+
"function_memory_size",
73+
"function_arn",
74+
"function_request_id",
75+
)
76+
for key in expected_logger_context_keys:
77+
assert key in log
78+
79+
# Verify the actual values match the embedded lambda_context
80+
assert log["function_name"] == durable_context.lambda_context.function_name
81+
assert log["function_memory_size"] == durable_context.lambda_context.memory_limit_in_mb
82+
assert log["function_arn"] == durable_context.lambda_context.invoked_function_arn
83+
assert log["function_request_id"] == durable_context.lambda_context.aws_request_id
84+
assert log["message"] == "Hello from durable function"
85+
86+
87+
def test_inject_lambda_context_with_durable_context_log_event(durable_context, stdout, service_name):
88+
"""Test that inject_lambda_context with log_event=True works with DurableContext."""
89+
# GIVEN Logger is initialized
90+
logger = Logger(service=service_name, stream=stdout)
91+
92+
test_event = {"test_key": "test_value"}
93+
94+
# WHEN a lambda function is decorated with log_event=True and receives DurableContext
95+
@logger.inject_lambda_context(log_event=True)
96+
def handler(event, context):
97+
logger.info("Processing event")
98+
99+
handler(test_event, durable_context)
100+
101+
# THEN both the event and lambda contextual info should be logged
102+
logs = capture_multiple_logging_statements_output(stdout)
103+
assert len(logs) >= 2 # At least event log and info log
104+
105+
# First log should be the event
106+
assert logs[0]["message"] == test_event
107+
108+
109+
def test_inject_lambda_context_with_durable_context_clear_state(durable_context, stdout, service_name):
110+
"""Test that inject_lambda_context with clear_state works with DurableContext."""
111+
# GIVEN Logger is initialized with custom keys
112+
logger = Logger(service=service_name, stream=stdout)
113+
logger.append_keys(custom_key="initial_value")
114+
115+
# WHEN a lambda function is decorated with clear_state=True and receives DurableContext
116+
@logger.inject_lambda_context(clear_state=True)
117+
def handler(event, context):
118+
logger.info("After clear state")
119+
120+
handler({}, durable_context)
121+
122+
# THEN the custom key should be cleared and lambda context should be present
123+
log = capture_logging_output(stdout)
124+
125+
# Lambda context fields should be present
126+
assert "function_name" in log
127+
assert log["function_name"] == durable_context.lambda_context.function_name
128+
129+
# Custom key should not be present (cleared)
130+
assert "custom_key" not in log or log.get("custom_key") != "initial_value"
131+
132+
133+
def test_inject_lambda_context_standard_context_still_works(lambda_context, stdout, service_name):
134+
"""Test that standard Lambda context still works (regression test)."""
135+
# GIVEN Logger is initialized
136+
logger = Logger(service=service_name, stream=stdout)
137+
138+
# WHEN a lambda function is decorated with logger and receives standard LambdaContext
139+
@logger.inject_lambda_context
140+
def handler(event, context):
141+
logger.info("Hello from standard lambda")
142+
143+
handler({}, lambda_context)
144+
145+
# THEN lambda contextual info should be in the logs
146+
log = capture_logging_output(stdout)
147+
148+
expected_logger_context_keys = (
149+
"function_name",
150+
"function_memory_size",
151+
"function_arn",
152+
"function_request_id",
153+
)
154+
for key in expected_logger_context_keys:
155+
assert key in log
156+
157+
assert log["function_name"] == lambda_context.function_name
158+
assert log["message"] == "Hello from standard lambda"

0 commit comments

Comments
 (0)