diff --git a/lambda/invoke_loop_test.go b/lambda/invoke_loop_test.go index af0e0c0c..4c305564 100644 --- a/lambda/invoke_loop_test.go +++ b/lambda/invoke_loop_test.go @@ -389,6 +389,46 @@ func TestContextDeserializationErrors(t *testing.T) { }`, string(record.responses[2])) } +func TestClientContextWithNestedCustomValues(t *testing.T) { + metadata := defaultInvokeMetadata() + metadata.clientContext = `{ + "Client": { + "app_title": "test", + "installation_id": "install1", + "app_version_code": "1.0", + "app_package_name": "com.test" + }, + "custom": { + "bedrockAgentCoreTargetId": "target-123", + "bedrockAgentCorePropagatedHeaders": {"x-id": "my-custom-id"} + } + }` + + ts, record := runtimeAPIServer(`{}`, 1, metadata) + defer ts.Close() + handler := NewHandler(func(ctx context.Context) (interface{}, error) { + lc, _ := lambdacontext.FromContext(ctx) + return lc.ClientContext, nil + }) + endpoint := strings.Split(ts.URL, "://")[1] + _ = startRuntimeAPILoop(endpoint, handler) + + expected := `{ + "Client": { + "installation_id": "install1", + "app_title": "test", + "app_version_code": "1.0", + "app_package_name": "com.test" + }, + "env": null, + "custom": { + "bedrockAgentCoreTargetId": "target-123", + "bedrockAgentCorePropagatedHeaders": "{\"x-id\": \"my-custom-id\"}" + } + }` + assert.JSONEq(t, expected, string(record.responses[0])) +} + type invalidPayload struct{} func (invalidPayload) MarshalJSON() ([]byte, error) { diff --git a/lambdacontext/context.go b/lambdacontext/context.go index d75d8282..f3e70399 100644 --- a/lambdacontext/context.go +++ b/lambdacontext/context.go @@ -11,6 +11,7 @@ package lambdacontext import ( "context" + "encoding/json" "os" "strconv" ) @@ -68,6 +69,35 @@ type ClientContext struct { Custom map[string]string `json:"custom"` } +// UnmarshalJSON implements custom JSON unmarshaling for ClientContext. +// This handles the case where values in the "custom" map are not strings +// (e.g. nested JSON objects), by serializing non-string values back to +// their JSON string representation. +func (cc *ClientContext) UnmarshalJSON(data []byte) error { + var raw struct { + Client ClientApplication `json:"Client"` + Env map[string]string `json:"env"` + Custom map[string]json.RawMessage `json:"custom"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + cc.Client = raw.Client + cc.Env = raw.Env + if raw.Custom != nil { + cc.Custom = make(map[string]string, len(raw.Custom)) + for k, v := range raw.Custom { + var s string + if err := json.Unmarshal(v, &s); err == nil { + cc.Custom[k] = s + } else { + cc.Custom[k] = string(v) + } + } + } + return nil +} + // CognitoIdentity is the cognito identity used by the calling application. type CognitoIdentity struct { CognitoIdentityID string diff --git a/lambdacontext/context_test.go b/lambdacontext/context_test.go new file mode 100644 index 00000000..28266373 --- /dev/null +++ b/lambdacontext/context_test.go @@ -0,0 +1,35 @@ +package lambdacontext + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestClientContextUnmarshalJSON(t *testing.T) { + t.Run("non-string custom values are serialized to string", func(t *testing.T) { + input := `{ + "Client": {"installation_id": "install1"}, + "custom": { + "key1": "stringval", + "key2": {"nested": "object"}, + "key3": 42 + } + }` + var cc ClientContext + err := json.Unmarshal([]byte(input), &cc) + require.NoError(t, err) + assert.Equal(t, "install1", cc.Client.InstallationID) + assert.Equal(t, "stringval", cc.Custom["key1"]) + assert.JSONEq(t, `{"nested":"object"}`, cc.Custom["key2"]) + assert.Equal(t, "42", cc.Custom["key3"]) + }) + + t.Run("invalid JSON returns error", func(t *testing.T) { + var cc ClientContext + err := json.Unmarshal([]byte(`not valid json`), &cc) + assert.Error(t, err) + }) +}