diff --git a/README.md b/README.md index 605d7c16..7a2b387a 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ The Client Callbacks infrastructure processes message and channel status events, - **Callbacks Event Bus**: Domain-specific EventBridge bus for webhook orchestration - **API Destination Target Rules**: Per-client rules invoking HTTPS endpoints with client-specific authentication - **Client Config Storage**: S3 bucket storing client subscription configurations (status filters, webhook endpoints) -- **Per-Client DLQs**: SQS Dead Letter Queues for failed webhook deliveries (one per client) +- **Per-Client Target DLQs**: SQS Dead Letter Queues for failed webhook deliveries (one per client target) ### Event Flow diff --git a/infrastructure/terraform/components/callbacks/README.md b/infrastructure/terraform/components/callbacks/README.md index f5056ab5..078afa54 100644 --- a/infrastructure/terraform/components/callbacks/README.md +++ b/infrastructure/terraform/components/callbacks/README.md @@ -15,10 +15,10 @@ |------|-------------|------|---------|:--------:| | [applications\_map\_parameter\_name](#input\_applications\_map\_parameter\_name) | SSM Parameter Store path for the clientId-to-applicationData map, where applicationData is currently only the applicationId | `string` | `null` | no | | [aws\_account\_id](#input\_aws\_account\_id) | The AWS Account ID (numeric) | `string` | n/a | yes | -| [clients](#input\_clients) | n/a |
list(object({
connection_name = string
destination_name = string
invocation_endpoint = string
invocation_rate_limit_per_second = optional(number, 10)
http_method = optional(string, "POST")
header_name = optional(string, "x-api-key")
header_value = string
client_detail = list(string)
}))
| `[]` | no | +| [client\_config\_bucket\_force\_destroy](#input\_client\_config\_bucket\_force\_destroy) | Force-delete all objects and versions from the client config bucket during destroy | `bool` | `false` | no | | [component](#input\_component) | The variable encapsulating the name of this component | `string` | `"callbacks"` | no | | [default\_tags](#input\_default\_tags) | A map of default tags to apply to all taggable resources within the component | `map(string)` | `{}` | no | -| [deploy\_mock\_webhook](#input\_deploy\_mock\_webhook) | Flag to deploy mock webhook lambda for integration testing (test/dev environments only) | `bool` | `false` | no | +| [deploy\_mock\_clients](#input\_deploy\_mock\_clients) | Flag to deploy mock webhook lambda for integration testing (test/dev environments only) | `bool` | `false` | no | | [enable\_event\_anomaly\_detection](#input\_enable\_event\_anomaly\_detection) | Enable CloudWatch anomaly detection alarm for inbound event queue message reception | `bool` | `true` | no | | [environment](#input\_environment) | The name of the tfscaffold environment | `string` | n/a | yes | | [event\_anomaly\_band\_width](#input\_event\_anomaly\_band\_width) | The width of the anomaly detection band. Higher values (e.g. 4-6) reduce sensitivity and noise, lower values (e.g. 2-3) increase sensitivity. Recommended: 2-4. | `number` | `3` | no | @@ -55,7 +55,6 @@ |------|-------------| | [deployment](#output\_deployment) | Deployment details used for post-deployment scripts | | [mock\_webhook\_lambda\_log\_group\_name](#output\_mock\_webhook\_lambda\_log\_group\_name) | CloudWatch log group name for mock webhook lambda (for integration test queries) | -| [mock\_webhook\_url](#output\_mock\_webhook\_url) | URL endpoint for mock webhook (for TEST\_WEBHOOK\_URL environment variable) | diff --git a/infrastructure/terraform/components/callbacks/_paths.sh b/infrastructure/terraform/components/callbacks/_paths.sh new file mode 100644 index 00000000..9b9aba00 --- /dev/null +++ b/infrastructure/terraform/components/callbacks/_paths.sh @@ -0,0 +1,8 @@ +_paths_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +repo_root="$(cd "${_paths_dir}/../../../.." && pwd)" +clients_dir="${repo_root}/infrastructure/terraform/modules/clients" + +# Follow symlinks to find the real nhs-notify-client-callbacks root +# (repo_root resolves to the workspace root, which differs in CI where the component is symlinked in) +_real_script="$(readlink -f "${BASH_SOURCE[0]}")" +bounded_context_root="$(cd "$(dirname "${_real_script}")/../../../.." && pwd)" diff --git a/infrastructure/terraform/components/callbacks/cloudwatch_metric_alarm_dlq_depth.tf b/infrastructure/terraform/components/callbacks/cloudwatch_metric_alarm_dlq_depth.tf index c38fb58f..e6ed2d9d 100644 --- a/infrastructure/terraform/components/callbacks/cloudwatch_metric_alarm_dlq_depth.tf +++ b/infrastructure/terraform/components/callbacks/cloudwatch_metric_alarm_dlq_depth.tf @@ -1,5 +1,5 @@ resource "aws_cloudwatch_metric_alarm" "client_dlq_depth" { - for_each = toset(keys(local.all_clients)) + for_each = toset(keys(local.config_targets)) alarm_name = "${local.csi}-${each.key}-dlq-depth" alarm_description = join(" ", [ @@ -25,7 +25,7 @@ resource "aws_cloudwatch_metric_alarm" "client_dlq_depth" { local.default_tags, { Name = "${local.csi}-${each.key}-dlq-depth" - Client = each.key + Client = local.config_targets[each.key].client_id }, ) } diff --git a/infrastructure/terraform/components/callbacks/locals.tf b/infrastructure/terraform/components/callbacks/locals.tf index b9f7d4d8..f4707154 100644 --- a/infrastructure/terraform/components/callbacks/locals.tf +++ b/infrastructure/terraform/components/callbacks/locals.tf @@ -4,29 +4,63 @@ locals { root_domain_name = "${var.environment}.${local.acct.route53_zone_names["client-callbacks"]}" # e.g. [main|dev|abxy0].smsnudge.[dev|nonprod|prod].nhsnotify.national.nhs.uk root_domain_id = local.acct.route53_zone_ids["client-callbacks"] - clients_by_name = { - for client in var.clients : - client.connection_name => client - } - - # Automatic test client when mock webhook is deployed - mock_client = var.deploy_mock_webhook ? { - "mock-client" = { - connection_name = "mock-client" - destination_name = "test-destination" - invocation_endpoint = aws_lambda_function_url.mock_webhook[0].function_url - invocation_rate_limit_per_second = 10 - http_method = "POST" - header_name = "x-api-key" - header_value = random_password.mock_webhook_api_key[0].result - client_detail = [ - "uk.nhs.notify.message.status.PUBLISHED.v1", - "uk.nhs.notify.channel.status.PUBLISHED.v1" + clients_dir_path = "${path.module}/../../modules/clients" + + config_clients = merge([ + for filename in fileset(local.clients_dir_path, "*.json") : { + (replace(filename, ".json", "")) = jsondecode(file("${local.clients_dir_path}/${filename}")) + } + ]...) + + # When deploying mock clients, replace sentinel placeholder values with the mock webhook URL and API key. + # Only used for S3 object content — must not be used as a for_each source (contains apply-time values). + enriched_mock_config_clients = var.deploy_mock_clients ? { + for client_id, client in local.config_clients : + client_id => merge(client, { + targets = [ + for target in try(client.targets, []) : + merge(target, { + invocationEndpoint = "${aws_lambda_function_url.mock_webhook[0].function_url}${target.targetId}" + apiKey = merge(target.apiKey, { headerValue = random_password.mock_webhook_api_key[0].result }) + }) ] + }) + } : local.config_clients + + + config_targets = merge([ + for client_id, data in local.config_clients : { + for target in try(data.targets, []) : target.targetId => { + client_id = client_id + target_id = target.targetId + invocation_endpoint = var.deploy_mock_clients ? "${aws_lambda_function_url.mock_webhook[0].function_url}${target.targetId}" : target.invocationEndpoint + invocation_rate_limit_per_second = target.invocationRateLimit + http_method = target.invocationMethod + header_name = target.apiKey.headerName + header_value = var.deploy_mock_clients ? random_password.mock_webhook_api_key[0].result : target.apiKey.headerValue + } + } + ]...) + + config_subscriptions = merge([ + for client_id, data in local.config_clients : { + for subscription in try(data.subscriptions, []) : subscription.subscriptionId => { + client_id = client_id + subscription_id = subscription.subscriptionId + target_ids = try(subscription.targetIds, []) + } } - } : {} + ]...) - all_clients = merge(local.clients_by_name, local.mock_client) + subscription_targets = merge([ + for subscription_id, subscription in local.config_subscriptions : { + for target_id in subscription.target_ids : + "${subscription_id}-${target_id}" => { + subscription_id = subscription_id + target_id = target_id + } + } + ]...) applications_map_parameter_name = coalesce(var.applications_map_parameter_name, "/${var.project}/${var.environment}/${var.component}/applications-map") } diff --git a/infrastructure/terraform/components/callbacks/module_client_destination.tf b/infrastructure/terraform/components/callbacks/module_client_destination.tf index 19f3c12f..21800e94 100644 --- a/infrastructure/terraform/components/callbacks/module_client_destination.tf +++ b/infrastructure/terraform/components/callbacks/module_client_destination.tf @@ -1,6 +1,5 @@ module "client_destination" { - source = "../../modules/client-destination" - for_each = local.all_clients + source = "../../modules/client-destination" project = var.project aws_account_id = var.aws_account_id @@ -11,16 +10,8 @@ module "client_destination" { kms_key_arn = module.kms.key_arn - connection_name = each.value.connection_name - destination_name = each.value.destination_name - invocation_endpoint = each.value.invocation_endpoint - invocation_rate_limit_per_second = each.value.invocation_rate_limit_per_second - http_method = each.value.http_method - header_name = each.value.header_name - header_value = each.value.header_value - client_detail = each.value.client_detail - - - + targets = local.config_targets + subscriptions = local.config_subscriptions + subscription_targets = local.subscription_targets } diff --git a/infrastructure/terraform/components/callbacks/module_mock_webhook_lambda.tf b/infrastructure/terraform/components/callbacks/module_mock_webhook_lambda.tf index 9a6de177..35f9a997 100644 --- a/infrastructure/terraform/components/callbacks/module_mock_webhook_lambda.tf +++ b/infrastructure/terraform/components/callbacks/module_mock_webhook_lambda.tf @@ -1,5 +1,5 @@ module "mock_webhook_lambda" { - count = var.deploy_mock_webhook ? 1 : 0 + count = var.deploy_mock_clients ? 1 : 0 source = "https://github.com/NHSDigital/nhs-notify-shared-modules/releases/download/3.0.6/terraform-lambda.zip" function_name = "mock-webhook" @@ -42,13 +42,13 @@ module "mock_webhook_lambda" { } resource "random_password" "mock_webhook_api_key" { - count = var.deploy_mock_webhook ? 1 : 0 + count = var.deploy_mock_clients ? 1 : 0 length = 32 special = false } data "aws_iam_policy_document" "mock_webhook_lambda" { - count = var.deploy_mock_webhook ? 1 : 0 + count = var.deploy_mock_clients ? 1 : 0 statement { sid = "KMSPermissions" @@ -67,7 +67,7 @@ data "aws_iam_policy_document" "mock_webhook_lambda" { # Lambda Function URL for mock webhook (test/dev only) resource "aws_lambda_function_url" "mock_webhook" { - count = var.deploy_mock_webhook ? 1 : 0 + count = var.deploy_mock_clients ? 1 : 0 function_name = module.mock_webhook_lambda[0].function_name authorization_type = "NONE" # Public endpoint for testing @@ -80,7 +80,7 @@ resource "aws_lambda_function_url" "mock_webhook" { } resource "aws_lambda_permission" "mock_webhook_function_url" { - count = var.deploy_mock_webhook ? 1 : 0 + count = var.deploy_mock_clients ? 1 : 0 statement_id_prefix = "FunctionURLAllowPublicAccess" action = "lambda:InvokeFunctionUrl" function_name = module.mock_webhook_lambda[0].function_name @@ -89,7 +89,7 @@ resource "aws_lambda_permission" "mock_webhook_function_url" { } resource "aws_lambda_permission" "mock_webhook_function_invoke" { - count = var.deploy_mock_webhook ? 1 : 0 + count = var.deploy_mock_clients ? 1 : 0 statement_id_prefix = "FunctionURLAllowInvokeAction" action = "lambda:InvokeFunction" function_name = module.mock_webhook_lambda[0].function_name diff --git a/infrastructure/terraform/components/callbacks/outputs.tf b/infrastructure/terraform/components/callbacks/outputs.tf index 3daaa8b2..1ca00df8 100644 --- a/infrastructure/terraform/components/callbacks/outputs.tf +++ b/infrastructure/terraform/components/callbacks/outputs.tf @@ -20,10 +20,5 @@ output "deployment" { output "mock_webhook_lambda_log_group_name" { description = "CloudWatch log group name for mock webhook lambda (for integration test queries)" - value = var.deploy_mock_webhook ? module.mock_webhook_lambda[0].cloudwatch_log_group_name : null -} - -output "mock_webhook_url" { - description = "URL endpoint for mock webhook (for TEST_WEBHOOK_URL environment variable)" - value = var.deploy_mock_webhook ? aws_lambda_function_url.mock_webhook[0].function_url : null + value = var.deploy_mock_clients ? module.mock_webhook_lambda[0].cloudwatch_log_group_name : null } diff --git a/infrastructure/terraform/components/callbacks/pipes_pipe_main.tf b/infrastructure/terraform/components/callbacks/pipes_pipe_main.tf index 6c088133..3fddfcca 100644 --- a/infrastructure/terraform/components/callbacks/pipes_pipe_main.tf +++ b/infrastructure/terraform/components/callbacks/pipes_pipe_main.tf @@ -25,9 +25,9 @@ resource "aws_pipes_pipe" "main" { input_template = <, - "transformedPayload": <$.transformedPayload>, - "headers": <$.headers> + "payload": <$.payload>, + "subscriptions": <$.subscriptions>, + "signatures": <$.signatures> } EOF } diff --git a/infrastructure/terraform/components/callbacks/pre.sh b/infrastructure/terraform/components/callbacks/pre.sh index 6f3957ec..873cdf5a 100755 --- a/infrastructure/terraform/components/callbacks/pre.sh +++ b/infrastructure/terraform/components/callbacks/pre.sh @@ -1,9 +1,32 @@ -# # This script is run before the Terraform apply command. -# # It ensures all Node.js dependencies are installed, generates any required dependencies, -# # and builds all Lambda functions in the workspace before Terraform provisions infrastructure. +# This script is run before the Terraform apply command. +# It ensures dependencies are installed, generates local client config files +# for terraform from S3-held subscriptions, and builds lambda workspaces. + +script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +# shellcheck source=_paths.sh +source "${script_dir}/_paths.sh" + +# Resolve deploy_mock_clients from tfvars; base_path/group/region/environment are in scope from terraform.sh +deploy_mock_clients="false" +for _tfvar_file in \ + "${base_path}/etc/group_${group}.tfvars" \ + "${base_path}/etc/env_${region}_${environment}.tfvars"; do + if [ -f "${_tfvar_file}" ]; then + _val=$(grep -E '^\s*deploy_mock_clients\s*=' "${_tfvar_file}" | tail -1 | sed 's/.*=\s*//;s/\s*$//') + [ -n "${_val}" ] && deploy_mock_clients="${_val}" + fi +done +echo "deploy_mock_clients resolved to: ${deploy_mock_clients}" npm ci npm run generate-dependencies --workspaces --if-present +"${script_dir}/sync-client-config.sh" + +if [ "${deploy_mock_clients}" == "true" ]; then + cp "${bounded_context_root}/tests/integration/fixtures/"*.json "${clients_dir}/" + echo "Copied mock client subscription config fixtures into clients dir" +fi + npm run lambda-build --workspaces --if-present diff --git a/infrastructure/terraform/components/callbacks/s3_bucket_client_config.tf b/infrastructure/terraform/components/callbacks/s3_bucket_client_config.tf index 58b016e4..533c7657 100644 --- a/infrastructure/terraform/components/callbacks/s3_bucket_client_config.tf +++ b/infrastructure/terraform/components/callbacks/s3_bucket_client_config.tf @@ -1,3 +1,16 @@ +resource "aws_s3_object" "mock_client_config" { + for_each = var.deploy_mock_clients ? toset(keys(local.config_clients)) : toset([]) + + bucket = module.client_config_bucket.id + key = "client_subscriptions/${local.config_clients[each.key].clientId}.json" + content = jsonencode(local.enriched_mock_config_clients[each.key]) + + kms_key_id = module.kms.key_arn + server_side_encryption = "aws:kms" + + content_type = "application/json" +} + module "client_config_bucket" { source = "https://github.com/NHSDigital/nhs-notify-shared-modules/releases/download/3.0.6/terraform-s3bucket.zip" @@ -17,7 +30,7 @@ module "client_config_bucket" { ) kms_key_arn = module.kms.key_arn - force_destroy = false + force_destroy = var.client_config_bucket_force_destroy versioning = true object_ownership = "BucketOwnerPreferred" bucket_key_enabled = true diff --git a/infrastructure/terraform/components/callbacks/ssm_parameter_applications_map.tf b/infrastructure/terraform/components/callbacks/ssm_parameter_applications_map.tf index 1e9b6925..567647d1 100644 --- a/infrastructure/terraform/components/callbacks/ssm_parameter_applications_map.tf +++ b/infrastructure/terraform/components/callbacks/ssm_parameter_applications_map.tf @@ -1,10 +1,16 @@ +resource "random_password" "mock_application_id" { + for_each = var.deploy_mock_clients ? toset(keys(local.config_clients)) : toset([]) + length = 24 + special = false +} + resource "aws_ssm_parameter" "applications_map" { name = local.applications_map_parameter_name type = "SecureString" key_id = module.kms.key_arn - value = var.deploy_mock_webhook ? jsonencode({ - "mock-client" = "mock-application-id" + value = var.deploy_mock_clients ? jsonencode({ + for id in keys(local.config_clients) : local.config_clients[id].clientId => random_password.mock_application_id[id].result }) : jsonencode({}) lifecycle { diff --git a/infrastructure/terraform/components/callbacks/sync-client-config.sh b/infrastructure/terraform/components/callbacks/sync-client-config.sh new file mode 100755 index 00000000..2c2a3ecb --- /dev/null +++ b/infrastructure/terraform/components/callbacks/sync-client-config.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash + +# Seeds local client subscription JSON files from S3 into modules/clients/ before Terraform runs. +# Terraform reads those files via fileset() to build local.config_clients. +# On first apply the bucket may not exist yet; this is handled gracefully. + +set -euo pipefail + +script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +# shellcheck source=_paths.sh +source "${script_dir}/_paths.sh" + +: "${ENVIRONMENT:?ENVIRONMENT must be set}" +: "${AWS_REGION:?AWS_REGION must be set}" +: "${AWS_ACCOUNT_ID:?AWS_ACCOUNT_ID must be set}" + +cd "${repo_root}" + +rm -f "${clients_dir}"/*.json + +bucket_name="nhs-${AWS_ACCOUNT_ID}-${AWS_REGION}-${ENVIRONMENT}-callbacks-subscription-config" + +s3_prefix="client_subscriptions/" + +echo "Seeding client configs from s3://${bucket_name}/${s3_prefix} for ${ENVIRONMENT}/${AWS_REGION}" + +if ! sync_output=$(aws s3 sync "s3://${bucket_name}/${s3_prefix}" "${clients_dir}/" \ + --region "${AWS_REGION}" \ + --exclude "*" \ + --include "*.json" \ + --only-show-errors 2>&1); then + if [[ "${sync_output}" == *"NoSuchBucket"* ]]; then + # Expected on first apply before Terraform creates the bucket. + echo "Client config bucket not found yet; skipping sync for first run" + else + echo "Failed to sync client config from S3" >&2 + echo "${sync_output}" >&2 + exit 1 + fi +fi + +# Ensure an empty directory produces a zero-length array rather than a literal "*.json" entry. +shopt -s nullglob +seeded_files=("${clients_dir}"/*.json) +seeded_count="${#seeded_files[@]}" +shopt -u nullglob + +echo "Seeded ${seeded_count} client config file(s)" diff --git a/infrastructure/terraform/components/callbacks/variables.tf b/infrastructure/terraform/components/callbacks/variables.tf index b82546b0..ea06ba23 100644 --- a/infrastructure/terraform/components/callbacks/variables.tf +++ b/infrastructure/terraform/components/callbacks/variables.tf @@ -87,21 +87,7 @@ variable "pipe_event_patterns" { default = [] } -variable "clients" { - type = list(object({ - connection_name = string - destination_name = string - invocation_endpoint = string - invocation_rate_limit_per_second = optional(number, 10) - http_method = optional(string, "POST") - header_name = optional(string, "x-api-key") - header_value = string - client_detail = list(string) - })) - default = [] - -} variable "pipe_log_level" { type = string @@ -163,12 +149,18 @@ variable "event_anomaly_band_width" { } } -variable "deploy_mock_webhook" { +variable "deploy_mock_clients" { type = bool description = "Flag to deploy mock webhook lambda for integration testing (test/dev environments only)" default = false } +variable "client_config_bucket_force_destroy" { + type = bool + description = "Force-delete all objects and versions from the client config bucket during destroy" + default = false +} + variable "message_root_uri" { type = string description = "The root URI used for constructing message links in callback payloads" diff --git a/infrastructure/terraform/modules/client-destination/README.md b/infrastructure/terraform/modules/client-destination/README.md index 1cbd4706..31dd9fac 100644 --- a/infrastructure/terraform/modules/client-destination/README.md +++ b/infrastructure/terraform/modules/client-destination/README.md @@ -11,19 +11,14 @@ No requirements. |------|-------------|------|---------|:--------:| | [aws\_account\_id](#input\_aws\_account\_id) | Account ID | `string` | n/a | yes | | [client\_bus\_name](#input\_client\_bus\_name) | EventBus name where you create the rule | `string` | n/a | yes | -| [client\_detail](#input\_client\_detail) | Client Event Detail | `list(string)` | n/a | yes | | [component](#input\_component) | Component name | `string` | n/a | yes | -| [connection\_name](#input\_connection\_name) | Connection name | `string` | n/a | yes | -| [destination\_name](#input\_destination\_name) | Destination Name | `string` | n/a | yes | | [environment](#input\_environment) | The name of the tfscaffold environment | `string` | n/a | yes | -| [header\_name](#input\_header\_name) | Header name | `string` | n/a | yes | -| [header\_value](#input\_header\_value) | Header value | `string` | n/a | yes | -| [http\_method](#input\_http\_method) | HTTP Method | `string` | n/a | yes | -| [invocation\_endpoint](#input\_invocation\_endpoint) | Invocation Endpoint | `string` | n/a | yes | -| [invocation\_rate\_limit\_per\_second](#input\_invocation\_rate\_limit\_per\_second) | Invocation Rate Limit Per Second | `string` | n/a | yes | | [kms\_key\_arn](#input\_kms\_key\_arn) | KMS Key ARN | `string` | n/a | yes | | [project](#input\_project) | The name of the tfscaffold project | `string` | n/a | yes | | [region](#input\_region) | AWS Region | `string` | n/a | yes | +| [subscription\_targets](#input\_subscription\_targets) | Flattened subscription-target fanout map keyed by subscription-target composite key |
map(object({
subscription_id = string
target_id = string
}))
| n/a | yes | +| [subscriptions](#input\_subscriptions) | Flattened subscription definitions keyed by subscription\_id |
map(object({
client_id = string
subscription_id = string
target_ids = list(string)
}))
| n/a | yes | +| [targets](#input\_targets) | Flattened target definitions keyed by target\_id |
map(object({
client_id = string
target_id = string
invocation_endpoint = string
invocation_rate_limit_per_second = number
http_method = string
header_name = string
header_value = string
}))
| n/a | yes | ## Modules | Name | Source | Version | diff --git a/infrastructure/terraform/modules/client-destination/cloudwatch_event_api_destination_this.tf b/infrastructure/terraform/modules/client-destination/cloudwatch_event_api_destination_this.tf index 53499d92..4bec92cc 100644 --- a/infrastructure/terraform/modules/client-destination/cloudwatch_event_api_destination_this.tf +++ b/infrastructure/terraform/modules/client-destination/cloudwatch_event_api_destination_this.tf @@ -1,8 +1,10 @@ -resource "aws_cloudwatch_event_api_destination" "main" { - name = "${local.csi}-${var.destination_name}" - description = "API Destination for ${var.destination_name}" - invocation_endpoint = var.invocation_endpoint - http_method = var.http_method - invocation_rate_limit_per_second = var.invocation_rate_limit_per_second - connection_arn = aws_cloudwatch_event_connection.main.arn +resource "aws_cloudwatch_event_api_destination" "per_target" { + for_each = var.targets + + name = "${local.csi}-${each.key}" + description = "API Destination for ${each.key}" + invocation_endpoint = each.value.invocation_endpoint + http_method = each.value.http_method + invocation_rate_limit_per_second = each.value.invocation_rate_limit_per_second + connection_arn = aws_cloudwatch_event_connection.per_target[each.key].arn } diff --git a/infrastructure/terraform/modules/client-destination/cloudwatch_event_connection_main.tf b/infrastructure/terraform/modules/client-destination/cloudwatch_event_connection_main.tf index 7136d70b..7546d666 100644 --- a/infrastructure/terraform/modules/client-destination/cloudwatch_event_connection_main.tf +++ b/infrastructure/terraform/modules/client-destination/cloudwatch_event_connection_main.tf @@ -1,12 +1,14 @@ -resource "aws_cloudwatch_event_connection" "main" { - name = "${local.csi}-${var.connection_name}" - description = "Event Connection which would be used by API Destination ${var.connection_name}" +resource "aws_cloudwatch_event_connection" "per_target" { + for_each = var.targets + + name = "${local.csi}-${each.key}" + description = "Event Connection which would be used by API Destination ${each.key}" authorization_type = "API_KEY" auth_parameters { api_key { - key = var.header_name - value = var.header_value + key = each.value.header_name + value = each.value.header_value } } } diff --git a/infrastructure/terraform/modules/client-destination/cloudwatch_event_rule_main.tf b/infrastructure/terraform/modules/client-destination/cloudwatch_event_rule_main.tf index 4bce1003..bdf7ea47 100644 --- a/infrastructure/terraform/modules/client-destination/cloudwatch_event_rule_main.tf +++ b/infrastructure/terraform/modules/client-destination/cloudwatch_event_rule_main.tf @@ -1,30 +1,41 @@ -resource "aws_cloudwatch_event_rule" "main" { - name = "${local.csi}-${var.connection_name}" - description = "Client Callbacks event rule for inbound events" +resource "aws_cloudwatch_event_rule" "per_subscription" { + for_each = var.subscriptions + + name = "${local.csi}-${each.key}" + description = "Client Callbacks event rule for subscription ${each.key}" event_bus_name = var.client_bus_name event_pattern = jsonencode({ "detail" : { - "type" : var.client_detail + "subscriptions" : [each.value.subscription_id] } }) } -resource "aws_cloudwatch_event_target" "main" { - rule = aws_cloudwatch_event_rule.main.name - arn = aws_cloudwatch_event_api_destination.main.arn - target_id = "${local.csi}-${var.connection_name}" +resource "aws_cloudwatch_event_target" "per_subscription_target" { + for_each = var.subscription_targets + + rule = aws_cloudwatch_event_rule.per_subscription[each.value.subscription_id].name + arn = aws_cloudwatch_event_api_destination.per_target[each.value.target_id].arn + target_id = "${local.csi}-${each.value.target_id}" role_arn = aws_iam_role.api_target_role.arn event_bus_name = var.client_bus_name - input_path = "$.detail.transformedPayload" dead_letter_config { - arn = module.target_dlq.sqs_queue_arn + arn = module.target_dlq[each.value.target_id].sqs_queue_arn + } + + input_transformer { + input_paths = { + data = "$.detail.payload.data" + } + + input_template = "{\"data\": }" } http_target { header_parameters = { - "x-hmac-sha256-signature" = "$.detail.headers.x-hmac-sha256-signature" + "x-hmac-sha256-signature" = "$.detail.signatures.${replace(each.value.target_id, "-", "_")}" } } diff --git a/infrastructure/terraform/modules/client-destination/iam_role_api_target_role.tf b/infrastructure/terraform/modules/client-destination/iam_role_api_target_role.tf index 92c16aaa..bcab3490 100644 --- a/infrastructure/terraform/modules/client-destination/iam_role_api_target_role.tf +++ b/infrastructure/terraform/modules/client-destination/iam_role_api_target_role.tf @@ -31,7 +31,7 @@ resource "aws_iam_policy" "api_target_role" { data "aws_iam_policy_document" "api_target_role" { statement { - sid = replace("AllowAPIDestinationAccessFor${var.connection_name}", "-", "") + sid = "AllowAPIDestinationAccess" effect = "Allow" actions = [ @@ -39,12 +39,13 @@ data "aws_iam_policy_document" "api_target_role" { ] resources = [ - aws_cloudwatch_event_api_destination.main.arn + for destination in aws_cloudwatch_event_api_destination.per_target : + destination.arn ] } statement { - sid = replace("AllowSQSSendMessageForDLQFor${var.connection_name}", "-", "") + sid = "AllowSQSSendMessageForDLQ" effect = "Allow" actions = [ @@ -52,12 +53,13 @@ data "aws_iam_policy_document" "api_target_role" { ] resources = [ - module.target_dlq.sqs_queue_arn, + for dlq in module.target_dlq : + dlq.sqs_queue_arn ] } statement { - sid = replace("AllowKMSForDLQFor${var.connection_name}", "-", "") + sid = "AllowKMSForDLQ" effect = "Allow" actions = [ diff --git a/infrastructure/terraform/modules/client-destination/module_target_dlq.tf b/infrastructure/terraform/modules/client-destination/module_target_dlq.tf index 5a1457e5..3e2cd83b 100644 --- a/infrastructure/terraform/modules/client-destination/module_target_dlq.tf +++ b/infrastructure/terraform/modules/client-destination/module_target_dlq.tf @@ -1,12 +1,13 @@ module "target_dlq" { - source = "https://github.com/NHSDigital/nhs-notify-shared-modules/releases/download/3.0.6/terraform-sqs.zip" + source = "https://github.com/NHSDigital/nhs-notify-shared-modules/releases/download/3.0.6/terraform-sqs.zip" + for_each = var.targets aws_account_id = var.aws_account_id component = var.component environment = var.environment project = var.project region = var.region - name = "${var.connection_name}-dlq" + name = "${each.key}-dlq" sqs_kms_key_arn = var.kms_key_arn @@ -14,10 +15,12 @@ module "target_dlq" { create_dlq = false - sqs_policy_overload = data.aws_iam_policy_document.target_dlq.json + sqs_policy_overload = data.aws_iam_policy_document.target_dlq[each.key].json } data "aws_iam_policy_document" "target_dlq" { + for_each = var.targets + statement { sid = "AllowEventBridgeToSendMessage" effect = "Allow" @@ -32,7 +35,7 @@ data "aws_iam_policy_document" "target_dlq" { ] resources = [ - "arn:aws:sqs:${var.region}:${var.aws_account_id}:${var.project}-${var.environment}-${var.component}-${var.connection_name}-dlq-queue" + "arn:aws:sqs:${var.region}:${var.aws_account_id}:${var.project}-${var.environment}-${var.component}-${each.key}-dlq-queue" ] } } diff --git a/infrastructure/terraform/modules/client-destination/variables.tf b/infrastructure/terraform/modules/client-destination/variables.tf index a5360104..2b9a0ceb 100644 --- a/infrastructure/terraform/modules/client-destination/variables.tf +++ b/infrastructure/terraform/modules/client-destination/variables.tf @@ -23,44 +23,37 @@ variable "region" { description = "AWS Region" } -variable "connection_name" { - type = string - description = "Connection name" -} - -variable "header_name" { - type = string - description = "Header name" -} - -variable "header_value" { - type = string - description = "Header value" -} +variable "targets" { + type = map(object({ + client_id = string + target_id = string + invocation_endpoint = string + invocation_rate_limit_per_second = number + http_method = string + header_name = string + header_value = string + })) -variable "destination_name" { - type = string - description = "Destination Name" + description = "Flattened target definitions keyed by target_id" } -variable "invocation_endpoint" { - type = string - description = "Invocation Endpoint" -} +variable "subscriptions" { + type = map(object({ + client_id = string + subscription_id = string + target_ids = list(string) + })) -variable "invocation_rate_limit_per_second" { - type = string - description = "Invocation Rate Limit Per Second" + description = "Flattened subscription definitions keyed by subscription_id" } -variable "http_method" { - type = string - description = "HTTP Method" -} +variable "subscription_targets" { + type = map(object({ + subscription_id = string + target_id = string + })) -variable "client_detail" { - type = list(string) - description = "Client Event Detail" + description = "Flattened subscription-target fanout map keyed by subscription-target composite key" } variable "client_bus_name" { diff --git a/infrastructure/terraform/modules/clients/README.md b/infrastructure/terraform/modules/clients/README.md new file mode 100644 index 00000000..df8c1f5c --- /dev/null +++ b/infrastructure/terraform/modules/clients/README.md @@ -0,0 +1,19 @@ + + + + +## Requirements + +No requirements. +## Inputs + +No inputs. +## Modules + +No modules. +## Outputs + +No outputs. + + + diff --git a/lambdas/client-transform-filter-lambda/src/__tests__/index.component.test.ts b/lambdas/client-transform-filter-lambda/src/__tests__/index.component.test.ts index c524ef3c..b46c49f8 100644 --- a/lambdas/client-transform-filter-lambda/src/__tests__/index.component.test.ts +++ b/lambdas/client-transform-filter-lambda/src/__tests__/index.component.test.ts @@ -150,7 +150,10 @@ describe("Lambda handler with S3 subscription filtering", () => { expect(mockSend.mock.calls[0][0]).toBeInstanceOf(GetObjectCommand); expect(mockSsmSend).toHaveBeenCalledTimes(1); expect(mockSsmSend.mock.calls[0][0]).toBeInstanceOf(GetParameterCommand); - expect(result[0].headers["x-hmac-sha256-signature"]).toMatch(/^[0-9a-f]+$/); + expect(result[0]).toHaveProperty("payload"); + expect(result[0]).toHaveProperty("subscriptions"); + expect(result[0]).toHaveProperty("signatures"); + expect(Object.values(result[0].signatures)[0]).toMatch(/^[0-9a-f]+$/); }); it("filters out event when status is not in subscription", async () => { @@ -200,8 +203,10 @@ describe("Lambda handler with S3 subscription filtering", () => { // Only the DELIVERED event passes the filter expect(result).toHaveLength(1); - expect((result[0].data as { messageStatus: string }).messageStatus).toBe( - "DELIVERED", + expect(result[0].payload.data[0].attributes).toEqual( + expect.objectContaining({ + messageStatus: "delivered", + }), ); }); diff --git a/lambdas/client-transform-filter-lambda/src/__tests__/index.test.ts b/lambdas/client-transform-filter-lambda/src/__tests__/index.test.ts index 897e728f..5c6d495c 100644 --- a/lambdas/client-transform-filter-lambda/src/__tests__/index.test.ts +++ b/lambdas/client-transform-filter-lambda/src/__tests__/index.test.ts @@ -171,8 +171,10 @@ describe("Lambda handler", () => { const result = await handler([sqsMessage]); expect(result).toHaveLength(1); - expect(result[0]).toHaveProperty("transformedPayload"); - const dataItem = result[0].transformedPayload.data[0]; + expect(result[0]).toHaveProperty("payload"); + expect(result[0]).toHaveProperty("subscriptions"); + expect(result[0]).toHaveProperty("signatures"); + const dataItem = result[0].payload.data[0]; expect(dataItem.type).toBe("MessageStatus"); expect((dataItem.attributes as MessageStatusAttributes).messageStatus).toBe( "delivered", @@ -232,6 +234,65 @@ describe("Lambda handler", () => { ); }); + it("should throw when any target is missing an apiKey", async () => { + const customConfigLoader = { + loadClientConfig: jest.fn().mockResolvedValue( + createClientSubscriptionConfig("client-abc-123", { + subscriptions: [ + createMessageStatusSubscription(["DELIVERED"], { + targetIds: ["target-no-key", DEFAULT_TARGET_ID], + }), + ], + targets: [ + createTarget({ + targetId: "target-no-key", + apiKey: undefined as unknown as { + headerName: string; + headerValue: string; + }, + }), + createTarget({ + targetId: DEFAULT_TARGET_ID, + apiKey: { + headerName: "x-api-key", + headerValue: "valid-key", + }, + }), + ], + }), + ), + } as unknown as ConfigLoader; + + const handlerWithMixedTargets = createHandler({ + createObservabilityService: () => + new ObservabilityService(mockLogger, mockMetrics, mockMetricsLogger), + createConfigLoaderService: () => + ({ getLoader: () => customConfigLoader }) as ConfigLoaderService, + createApplicationsMapService: makeStubApplicationsMapService, + }); + + const sqsMessage: SQSRecord = { + messageId: "sqs-msg-id-mixed", + receiptHandle: "receipt-handle-mixed", + body: JSON.stringify(validMessageStatusEvent), + attributes: { + ApproximateReceiveCount: "1", + SentTimestamp: "1519211230", + SenderId: "ABCDEFGHIJ", + ApproximateFirstReceiveTimestamp: "1519211230", + }, + messageAttributes: {}, + md5OfBody: "mock-md5", + eventSource: "aws:sqs", + eventSourceARN: "arn:aws:sqs:eu-west-2:123456789:mock-queue", + awsRegion: "eu-west-2", + }; + + await expect(handlerWithMixedTargets([sqsMessage])).rejects.toThrow( + "Missing apiKey for target target-no-key", + ); + }); + it("should handle batch of SQS messages from EventBridge Pipes", async () => { const sqsMessages: SQSRecord[] = [ { @@ -271,8 +332,8 @@ describe("Lambda handler", () => { const result = await handler(sqsMessages); expect(result).toHaveLength(2); - expect(result[0]).toHaveProperty("transformedPayload"); - expect(result[1]).toHaveProperty("transformedPayload"); + expect(result[0]).toHaveProperty("payload"); + expect(result[1]).toHaveProperty("payload"); }); it("should reject event with unsupported type before reaching transformer", async () => { @@ -351,8 +412,10 @@ describe("Lambda handler", () => { const result = await handler([sqsMessage]); expect(result).toHaveLength(1); - expect(result[0]).toHaveProperty("transformedPayload"); - const dataItem = result[0].transformedPayload.data[0]; + expect(result[0]).toHaveProperty("payload"); + expect(result[0]).toHaveProperty("subscriptions"); + expect(result[0]).toHaveProperty("signatures"); + const dataItem = result[0].payload.data[0]; expect(dataItem.type).toBe("ChannelStatus"); expect((dataItem.attributes as ChannelStatusAttributes).channelStatus).toBe( "delivered", @@ -513,8 +576,8 @@ describe("Lambda handler", () => { const result = await handler(sqsMessages); expect(result).toHaveLength(2); - expect(result[0].transformedPayload.data[0].type).toBe("MessageStatus"); - expect(result[1].transformedPayload.data[0].type).toBe("ChannelStatus"); + expect(result[0].payload.data[0].type).toBe("MessageStatus"); + expect(result[1].payload.data[0].type).toBe("ChannelStatus"); }); }); diff --git a/lambdas/client-transform-filter-lambda/src/__tests__/services/subscription-filter.test.ts b/lambdas/client-transform-filter-lambda/src/__tests__/services/subscription-filter.test.ts index c302fda1..153ab934 100644 --- a/lambdas/client-transform-filter-lambda/src/__tests__/services/subscription-filter.test.ts +++ b/lambdas/client-transform-filter-lambda/src/__tests__/services/subscription-filter.test.ts @@ -113,6 +113,7 @@ describe("evaluateSubscriptionFilters", () => { matched: true, subscriptionType: "MessageStatus", targetIds: ["00000000-0000-4000-8000-000000000001"], + subscriptionIds: ["00000000-0000-0000-0000-000000000001"], }); }); @@ -128,14 +129,16 @@ describe("evaluateSubscriptionFilters", () => { }); }); - it("returns only matched subscription target IDs", () => { + it("returns only matched subscription target IDs and subscription IDs", () => { const event = createMessageStatusEvent("client-1", "DELIVERED"); const config = createClientSubscriptionConfig("client-1", { subscriptions: [ createMessageStatusSubscription(["DELIVERED"], { + subscriptionId: "sub-a", targetIds: ["target-a"], }), createMessageStatusSubscription(["FAILED"], { + subscriptionId: "sub-b", targetIds: ["target-b"], }), ], @@ -147,6 +150,7 @@ describe("evaluateSubscriptionFilters", () => { matched: true, subscriptionType: "MessageStatus", targetIds: ["target-a"], + subscriptionIds: ["sub-a"], }); }); }); @@ -174,6 +178,7 @@ describe("evaluateSubscriptionFilters", () => { matched: true, subscriptionType: "ChannelStatus", targetIds: ["00000000-0000-4000-8000-000000000001"], + subscriptionIds: ["00000000-0000-0000-0000-000000000002"], }); }); @@ -201,7 +206,7 @@ describe("evaluateSubscriptionFilters", () => { }); }); - it("returns only matched channel subscription target IDs", () => { + it("returns only matched channel subscription target IDs and subscription IDs", () => { const event = createChannelStatusEvent( "client-1", "SMS", @@ -217,6 +222,7 @@ describe("evaluateSubscriptionFilters", () => { ["delivered"], "EMAIL", { + subscriptionId: "sub-email", targetIds: ["target-email"], }, ), @@ -225,6 +231,7 @@ describe("evaluateSubscriptionFilters", () => { ["permanent_failure"], "SMS", { + subscriptionId: "sub-sms", targetIds: ["target-sms"], }, ), @@ -237,6 +244,7 @@ describe("evaluateSubscriptionFilters", () => { matched: true, subscriptionType: "ChannelStatus", targetIds: ["target-sms"], + subscriptionIds: ["sub-sms"], }); }); }); diff --git a/lambdas/client-transform-filter-lambda/src/handler.ts b/lambdas/client-transform-filter-lambda/src/handler.ts index 4c7ae933..0d1f20b6 100644 --- a/lambdas/client-transform-filter-lambda/src/handler.ts +++ b/lambdas/client-transform-filter-lambda/src/handler.ts @@ -22,9 +22,25 @@ type UnsignedEvent = StatusPublishEvent & { transformedPayload: ClientCallbackPayload; }; -export interface TransformedEvent extends StatusPublishEvent { - transformedPayload: ClientCallbackPayload; - headers: { "x-hmac-sha256-signature": string }; +type FilteredEvent = UnsignedEvent & { + subscriptionIds: string[]; + targetIds: string[]; +}; + +type SignedEvent = { + transformedEvent: TransformedEvent; + deliveryContext: { + correlationId: string; + eventType: string; + clientId: string; + messageId: string; + }; +}; + +export interface TransformedEvent { + payload: ClientCallbackPayload; + subscriptions: string[]; + signatures: Record; } class BatchStats { @@ -125,17 +141,17 @@ function processSingleEvent( type ClientConfigMap = Map; async function signBatch( - filteredEvents: UnsignedEvent[], + filteredEvents: FilteredEvent[], applicationsMapService: ApplicationsMapService, configByClientId: ClientConfigMap, stats: BatchStats, observability: ObservabilityService, -): Promise { +): Promise { const results = await pMap( filteredEvents, - async (event): Promise => { + async (event): Promise => { const { clientId } = event.data; - const correlationId = extractCorrelationId(event); + const correlationId = extractCorrelationId(event) ?? event.id; const applicationId = await applicationsMapService.getApplicationId(clientId); @@ -149,53 +165,52 @@ async function signBatch( } const clientConfig = configByClientId.get(clientId); - const apiKey = clientConfig?.targets?.[0]?.apiKey?.headerValue; - if (!apiKey) { - stats.recordFiltered(); - logger.warn( - "No apiKey in client config - event will not be delivered", - { clientId, correlationId }, + const targetsById = new Map( + (clientConfig?.targets ?? []).map((t) => [t.targetId, t]), + ); + + const signaturesByTarget = new Map(); + + for (const targetId of event.targetIds) { + const target = targetsById.get(targetId); + const apiKey = target?.apiKey?.headerValue; + if (!apiKey) { + throw new ValidationError( + `Missing apiKey for target ${targetId}`, + correlationId, + ); + } + const signature = signPayload( + event.transformedPayload, + applicationId, + apiKey, + ); + signaturesByTarget.set(targetId.replaceAll("-", "_"), signature); + observability.recordCallbackSigned( + event.transformedPayload, + correlationId, + clientId, + signature, ); - return undefined; } - const signature = signPayload( - event.transformedPayload, - applicationId, - apiKey, - ); - const signedEvent: TransformedEvent = { - ...event, - headers: { "x-hmac-sha256-signature": signature }, + return { + transformedEvent: { + payload: event.transformedPayload, + subscriptions: event.subscriptionIds, + signatures: Object.fromEntries(signaturesByTarget), + }, + deliveryContext: { + correlationId, + eventType: event.type, + clientId, + messageId: event.data.messageId, + }, }; - observability.recordCallbackSigned( - signedEvent.transformedPayload, - correlationId, - clientId, - signature, - ); - return signedEvent; }, { concurrency: BATCH_CONCURRENCY }, ); - return results.filter((e): e is TransformedEvent => e !== undefined); -} - -function recordDeliveryInitiated( - transformedEvents: TransformedEvent[], - observability: ObservabilityService, -): void { - for (const transformedEvent of transformedEvents) { - const { clientId, messageId } = transformedEvent.data; - const correlationId = extractCorrelationId(transformedEvent); - - observability.recordDeliveryInitiated({ - correlationId, - eventType: transformedEvent.type, - clientId, - messageId, - }); - } + return results.filter((e): e is SignedEvent => e !== undefined); } async function loadClientConfigs( @@ -219,10 +234,10 @@ async function filterBatch( configByClientId: ClientConfigMap, observability: ObservabilityService, stats: BatchStats, -): Promise { +): Promise { observability.recordFilteringStarted({ batchSize: transformedEvents.length }); - const filtered: UnsignedEvent[] = []; + const filtered: FilteredEvent[] = []; for (const event of transformedEvents) { const { clientId } = event.data; @@ -231,7 +246,11 @@ async function filterBatch( const filterResult = evaluateSubscriptionFilters(event, config); if (filterResult.matched) { - filtered.push(event); + filtered.push({ + ...event, + subscriptionIds: filterResult.subscriptionIds ?? [], + targetIds: filterResult.targetIds ?? [], + }); observability.recordFilteringMatched({ correlationId, clientId, @@ -313,6 +332,14 @@ export async function processEvents( observability, ); + for (const signedEvent of signedEvents) { + observability.recordDeliveryInitiated(signedEvent.deliveryContext); + } + + const deliverableEvents = signedEvents.map( + (signedEvent) => signedEvent.transformedEvent, + ); + const processingTime = Date.now() - startTime; observability.logBatchProcessingCompleted({ ...stats.toObject(), @@ -320,10 +347,8 @@ export async function processEvents( processingTimeMs: processingTime, }); - recordDeliveryInitiated(signedEvents, observability); - await observability.flush(); - return signedEvents; + return deliverableEvents; } catch (error) { stats.recordFailure(); diff --git a/lambdas/client-transform-filter-lambda/src/services/subscription-filter.ts b/lambdas/client-transform-filter-lambda/src/services/subscription-filter.ts index 55c131c7..2a51627f 100644 --- a/lambdas/client-transform-filter-lambda/src/services/subscription-filter.ts +++ b/lambdas/client-transform-filter-lambda/src/services/subscription-filter.ts @@ -14,6 +14,7 @@ type FilterResult = { matched: boolean; subscriptionType: "MessageStatus" | "ChannelStatus" | "Unknown"; targetIds?: string[]; + subscriptionIds?: string[]; }; const unique = (values: string[]): string[] => [...new Set(values)]; @@ -30,48 +31,60 @@ export const evaluateSubscriptionFilters = ( } if (event.type === EventTypes.MESSAGE_STATUS_PUBLISHED) { - const typedEvent = event as StatusPublishEvent; + const matchingSubscriptions = config.subscriptions.filter((subscription) => + matchesMessageStatusSubscription( + { + ...config, + subscriptions: [subscription], + }, + event as StatusPublishEvent, + ), + ); const matchingTargetIds = unique( - config.subscriptions - .filter((subscription) => - matchesMessageStatusSubscription( - { - ...config, - subscriptions: [subscription], - }, - typedEvent, - ), - ) - .flatMap((subscription) => subscription.targetIds), + matchingSubscriptions.flatMap((subscription) => subscription.targetIds), + ); + const matchingSubscriptionIds = unique( + matchingSubscriptions.map((subscription) => subscription.subscriptionId), ); return { matched: matchingTargetIds.length > 0, subscriptionType: "MessageStatus", - ...(matchingTargetIds.length > 0 ? { targetIds: matchingTargetIds } : {}), + ...(matchingTargetIds.length > 0 + ? { + targetIds: matchingTargetIds, + subscriptionIds: matchingSubscriptionIds, + } + : {}), }; } if (event.type === EventTypes.CHANNEL_STATUS_PUBLISHED) { - const typedEvent = event as StatusPublishEvent; + const matchingSubscriptions = config.subscriptions.filter((subscription) => + matchesChannelStatusSubscription( + { + ...config, + subscriptions: [subscription], + }, + event as StatusPublishEvent, + ), + ); const matchingTargetIds = unique( - config.subscriptions - .filter((subscription) => - matchesChannelStatusSubscription( - { - ...config, - subscriptions: [subscription], - }, - typedEvent, - ), - ) - .flatMap((subscription) => subscription.targetIds), + matchingSubscriptions.flatMap((subscription) => subscription.targetIds), + ); + const matchingSubscriptionIds = unique( + matchingSubscriptions.map((subscription) => subscription.subscriptionId), ); return { matched: matchingTargetIds.length > 0, subscriptionType: "ChannelStatus", - ...(matchingTargetIds.length > 0 ? { targetIds: matchingTargetIds } : {}), + ...(matchingTargetIds.length > 0 + ? { + targetIds: matchingTargetIds, + subscriptionIds: matchingSubscriptionIds, + } + : {}), }; } diff --git a/lambdas/mock-webhook-lambda/jest.config.ts b/lambdas/mock-webhook-lambda/jest.config.ts index 571b3a87..017ed2db 100644 --- a/lambdas/mock-webhook-lambda/jest.config.ts +++ b/lambdas/mock-webhook-lambda/jest.config.ts @@ -5,6 +5,7 @@ export default { coverageThreshold: { global: { ...nodeJestConfig.coverageThreshold?.global, + branches: 100, lines: 100, statements: 100, }, diff --git a/lambdas/mock-webhook-lambda/src/__tests__/index.test.ts b/lambdas/mock-webhook-lambda/src/__tests__/index.test.ts index 0cf93f1b..056dc397 100644 --- a/lambdas/mock-webhook-lambda/src/__tests__/index.test.ts +++ b/lambdas/mock-webhook-lambda/src/__tests__/index.test.ts @@ -28,8 +28,9 @@ const DEFAULT_HEADERS = { const createMockEvent = ( body: string | null, headers: Record = DEFAULT_HEADERS, + rawPath?: string, ): APIGatewayProxyEvent => - ({ body, headers }) as unknown as APIGatewayProxyEvent; + ({ body, headers, rawPath }) as unknown as APIGatewayProxyEvent; describe("Mock Webhook Lambda", () => { beforeAll(() => { @@ -86,12 +87,20 @@ describe("Mock Webhook Lambda", () => { ], }; - const event = createMockEvent(JSON.stringify(callback)); + const event = createMockEvent( + JSON.stringify(callback), + DEFAULT_HEADERS, + "/target-abc-123", + ); const result = await handler(event); expect(result.statusCode).toBe(200); const body = JSON.parse(result.body); expect(body.message).toBe("Callback received"); + expect(mockLogger.info).toHaveBeenCalledWith( + "Callback received", + expect.objectContaining({ path: "/target-abc-123" }), + ); }); it("should accept and log ChannelStatus callback", async () => { @@ -328,7 +337,7 @@ describe("Mock Webhook Lambda", () => { }); describe("Logging", () => { - it("should log callback with structured format including messageId", async () => { + it("should log Callback received with structured context", async () => { const callback = { data: [ { @@ -347,40 +356,25 @@ describe("Mock Webhook Lambda", () => { ], }; - const event = createMockEvent(JSON.stringify(callback)); - await handler(event); - - const callbackCall = mockLogger.info.mock.calls.find( - ([message]: [string]) => - typeof message === "string" && message.startsWith("CALLBACK"), - ); - - expect(callbackCall).toBeDefined(); - const [message, context] = callbackCall as [ - string, - Record, - ]; - expect(message).toContain("some-idempotency-key"); - expect(message).toContain("MessageStatus"); - expect(context).toMatchObject({ - correlationId: "some-idempotency-key", - messageId: "test-msg-789", - messageType: "MessageStatus", + const event = createMockEvent(JSON.stringify(callback), { + ...DEFAULT_HEADERS, + "x-hmac-sha256-signature": "test-sig", }); + await handler(event); const receivedCall = mockLogger.info.mock.calls.find( ([msg]: [string]) => msg === "Callback received", ); + expect(receivedCall).toBeDefined(); - const [, receivedContext] = receivedCall as [ - string, - Record, - ]; - expect(receivedContext).toMatchObject({ + const [, context] = receivedCall as [string, Record]; + expect(context).toMatchObject({ + correlationId: "some-idempotency-key", messageId: "test-msg-789", callbackType: "MessageStatus", - signature: "", + signature: "test-sig", }); + expect(context).toHaveProperty("payload"); }); }); }); diff --git a/lambdas/mock-webhook-lambda/src/index.ts b/lambdas/mock-webhook-lambda/src/index.ts index feddddb6..8104e31a 100644 --- a/lambdas/mock-webhook-lambda/src/index.ts +++ b/lambdas/mock-webhook-lambda/src/index.ts @@ -36,15 +36,25 @@ function isClientCallbackPayload( async function buildResponse( event: APIGatewayProxyEvent, ): Promise { - logger.info("Mock webhook invoked", { - path: event.path, - method: event.httpMethod, - }); - + const eventWithFunctionUrlFields = event as APIGatewayProxyEvent & { + rawPath?: string; + requestContext?: { http?: { method?: string } }; + }; const headers = Object.fromEntries( Object.entries(event.headers).map(([k, v]) => [k.toLowerCase(), v]), ) as Record; + const path = event.path ?? eventWithFunctionUrlFields.rawPath; + + logger.info("Mock webhook invoked", { + path, + method: event.httpMethod, + hasBody: Boolean(event.body), + "x-api-key": headers["x-api-key"], + "x-hmac-sha256-signature": headers["x-hmac-sha256-signature"], + payload: event.body, + }); + const expectedApiKey = process.env.API_KEY; const providedApiKey = headers["x-api-key"]; @@ -68,6 +78,8 @@ async function buildResponse( try { const parsed = JSON.parse(event.body) as unknown; + logger.info("Mock webhook parsed payload", { parsedPayload: parsed }); + if (!isClientCallbackPayload(parsed)) { logger.error("Invalid message structure - missing or invalid data array"); @@ -107,18 +119,12 @@ async function buildResponse( }; } - logger.info( - `CALLBACK ${correlationId} ${item.type} : ${JSON.stringify(item)}`, - { - correlationId, - messageId, - messageType: item.type, - }, - ); - logger.info("Callback received", { + correlationId, messageId, callbackType: item.type, + path, + apiKey: providedApiKey, signature: headers["x-hmac-sha256-signature"] ?? "", payload: JSON.stringify(item), }); diff --git a/package.json b/package.json index 68238b47..d5b58eef 100644 --- a/package.json +++ b/package.json @@ -52,6 +52,8 @@ "test:unit:silent": "LOG_LEVEL=silent npm run test:unit --workspaces", "typecheck": "npm run typecheck --workspaces", "verify": "npm run lint && npm run typecheck && npm run test:unit", + "applications-map:add": "npm run --silent applications-map-add --workspace tools/client-subscriptions-management --", + "applications-map:get": "npm run --silent applications-map-get --workspace tools/client-subscriptions-management --", "clients:list": "npm run --silent clients-list --workspace tools/client-subscriptions-management --", "clients:get": "npm run --silent clients-get --workspace tools/client-subscriptions-management --", "clients:put": "npm run --silent clients-put --workspace tools/client-subscriptions-management --", diff --git a/scripts/tests/integration-env.sh b/scripts/tests/integration-env.sh new file mode 100644 index 00000000..4d5eb014 --- /dev/null +++ b/scripts/tests/integration-env.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +set -euo pipefail + +: "${ENVIRONMENT:?ENVIRONMENT must be set}" +: "${AWS_REGION:?AWS_REGION must be set}" + +# Add new clients here: "fixture-filename.json:ENV_VAR_PREFIX" +CLIENTS=( + "mock-client-subscription-1.json:MOCK_CLIENT" + "mock-client-subscription-2.json:MOCK_CLIENT_2" +) + +for CLIENT_ENTRY in "${CLIENTS[@]}"; do + FIXTURE="${CLIENT_ENTRY%%:*}" + PREFIX="${CLIENT_ENTRY##*:}" + + SEED_CONFIG_FILE="$(pwd)/tests/integration/fixtures/${FIXTURE}" + CLIENT_ID=$(jq -r '.clientId' "${SEED_CONFIG_FILE}") + + echo "Retrieving client config for ${CLIENT_ID}" + CLIENT_CONFIG=$(npm run --silent clients:get -- \ + --client-id "${CLIENT_ID}" \ + --environment "${ENVIRONMENT}" \ + --region "${AWS_REGION}") + + echo "Retrieving application ID for ${CLIENT_ID}" + APPLICATION_ID=$(npm run --silent applications-map:get -- \ + --client-id "${CLIENT_ID}" \ + --environment "${ENVIRONMENT}" \ + --region "${AWS_REGION}") + + export "${PREFIX}_API_KEY=$(echo "${CLIENT_CONFIG}" | jq -r '.targets[0].apiKey.headerValue')" + export "${PREFIX}_APPLICATION_ID=${APPLICATION_ID}" +done diff --git a/scripts/tests/integration.sh b/scripts/tests/integration.sh index 8460d02b..fcc89389 100755 --- a/scripts/tests/integration.sh +++ b/scripts/tests/integration.sh @@ -5,4 +5,7 @@ set -euo pipefail cd "$(git rev-parse --show-toplevel)" npm ci + +source ./scripts/tests/integration-env.sh + npm run test:integration diff --git a/tests/integration/dlq-redrive.test.ts b/tests/integration/dlq-redrive.test.ts index f0406004..0645caa2 100644 --- a/tests/integration/dlq-redrive.test.ts +++ b/tests/integration/dlq-redrive.test.ts @@ -4,11 +4,12 @@ import type { StatusPublishEvent, } from "@nhs-notify-client-callbacks/models"; import { + assertCallbackHeaders, awaitSignedCallbacksFromWebhookLogGroup, buildInboundEventQueueUrl, buildLambdaLogGroupName, buildMockClientDlqQueueUrl, - computeExpectedSignature, + buildMockWebhookTargetPath, createCloudWatchLogsClient, createMessageStatusPublishEvent, createSqsClient, @@ -50,7 +51,7 @@ describe("DLQ Redrive", () => { }); describe("Infrastructure validation", () => { - it("should confirm the mock-client DLQ is accessible", async () => { + it("should confirm the target DLQ is accessible", async () => { const response = await sqsClient.send( new GetQueueAttributesCommand({ QueueUrl: dlqQueueUrl, @@ -93,6 +94,7 @@ describe("DLQ Redrive", () => { webhookLogGroupName, event.data.messageId, "MessageStatus", + buildMockWebhookTargetPath(), startTime, ); @@ -103,9 +105,7 @@ describe("DLQ Redrive", () => { messageStatus: "delivered", }), }); - expect(callbacks[0].headers["x-hmac-sha256-signature"]).toBe( - computeExpectedSignature(callbacks[0].payload), - ); + assertCallbackHeaders(callbacks[0]); }, 120_000); it("should apply the same transformation logic to redriven events as original deliveries", async () => { @@ -142,21 +142,24 @@ describe("DLQ Redrive", () => { expect(dlqPayload.data.messageId).toBe(redriveEvent.data.messageId); - const directCallbacks = await awaitSignedCallbacksFromWebhookLogGroup( - cloudWatchClient, - webhookLogGroupName, - directEvent.data.messageId, - "MessageStatus", - startTime, - ); - - const redriveCallbacks = await awaitSignedCallbacksFromWebhookLogGroup( - cloudWatchClient, - webhookLogGroupName, - redriveEvent.data.messageId, - "MessageStatus", - startTime, - ); + const [directCallbacks, redriveCallbacks] = await Promise.all([ + awaitSignedCallbacksFromWebhookLogGroup( + cloudWatchClient, + webhookLogGroupName, + directEvent.data.messageId, + "MessageStatus", + buildMockWebhookTargetPath(), + startTime, + ), + awaitSignedCallbacksFromWebhookLogGroup( + cloudWatchClient, + webhookLogGroupName, + redriveEvent.data.messageId, + "MessageStatus", + buildMockWebhookTargetPath(), + startTime, + ), + ]); await ensureInboundQueueIsEmpty(sqsClient, inboundQueueUrl); @@ -168,9 +171,7 @@ describe("DLQ Redrive", () => { ).messageStatus, }), }); - expect(redriveCallbacks[0].headers["x-hmac-sha256-signature"]).toBe( - computeExpectedSignature(redriveCallbacks[0].payload), - ); + assertCallbackHeaders(redriveCallbacks[0]); }, 120_000); }); }); diff --git a/tests/integration/fixtures/mock-client-subscription-1.json b/tests/integration/fixtures/mock-client-subscription-1.json new file mode 100644 index 00000000..1e76ad65 --- /dev/null +++ b/tests/integration/fixtures/mock-client-subscription-1.json @@ -0,0 +1,45 @@ +{ + "clientId": "mock-client-1", + "subscriptions": [ + { + "messageStatuses": [ + "DELIVERED", + "FAILED" + ], + "subscriptionId": "sub-28fc741d-de6c-41a9-8fb0-89c4115c7dcf", + "subscriptionType": "MessageStatus", + "targetIds": [ + "target-23b2ee2f-8e81-43cd-9bb8-5ea30a09f779" + ] + }, + { + "channelStatuses": [ + "DELIVERED", + "FAILED" + ], + "channelType": "NHSAPP", + "subscriptionId": "sub-0fa7076d-5640-47ea-9f8f-70ff20778571", + "subscriptionType": "ChannelStatus", + "supplierStatuses": [ + "delivered", + "permanent_failure" + ], + "targetIds": [ + "target-23b2ee2f-8e81-43cd-9bb8-5ea30a09f779" + ] + } + ], + "targets": [ + { + "apiKey": { + "headerName": "x-api-key", + "headerValue": "REPLACED_BY_TERRAFORM" + }, + "invocationEndpoint": "https://REPLACED_BY_TERRAFORM", + "invocationMethod": "POST", + "invocationRateLimit": 10, + "targetId": "target-23b2ee2f-8e81-43cd-9bb8-5ea30a09f779", + "type": "API" + } + ] +} diff --git a/tests/integration/fixtures/mock-client-subscription-2.json b/tests/integration/fixtures/mock-client-subscription-2.json new file mode 100644 index 00000000..ee7091cd --- /dev/null +++ b/tests/integration/fixtures/mock-client-subscription-2.json @@ -0,0 +1,41 @@ +{ + "clientId": "mock-client-2", + "subscriptions": [ + { + "messageStatuses": [ + "DELIVERED", + "FAILED" + ], + "subscriptionId": "sub-8afc2d3f-0645-4d35-9f3a-508699602f4b", + "subscriptionType": "MessageStatus", + "targetIds": [ + "target-1f3aa57d-c0b6-4a0a-a8e9-c7f97f1e27e7", + "target-c23f4ad8-2b6f-4510-b5b6-40f2b7fbbec5" + ] + } + ], + "targets": [ + { + "apiKey": { + "headerName": "x-api-key", + "headerValue": "REPLACED_BY_TERRAFORM" + }, + "invocationEndpoint": "https://REPLACED_BY_TERRAFORM", + "invocationMethod": "POST", + "invocationRateLimit": 10, + "targetId": "target-1f3aa57d-c0b6-4a0a-a8e9-c7f97f1e27e7", + "type": "API" + }, + { + "apiKey": { + "headerName": "x-api-key", + "headerValue": "REPLACED_BY_TERRAFORM" + }, + "invocationEndpoint": "https://REPLACED_BY_TERRAFORM", + "invocationMethod": "POST", + "invocationRateLimit": 10, + "targetId": "target-c23f4ad8-2b6f-4510-b5b6-40f2b7fbbec5", + "type": "API" + } + ] +} diff --git a/tests/integration/helpers/cloudwatch.ts b/tests/integration/helpers/cloudwatch.ts index 84559f1a..c57ce48d 100644 --- a/tests/integration/helpers/cloudwatch.ts +++ b/tests/integration/helpers/cloudwatch.ts @@ -9,19 +9,28 @@ import { TimeoutError, waitUntil } from "async-wait-until"; const CALLBACK_WAIT_TIMEOUT_MS = 60_000; const METRICS_WAIT_TIMEOUT_MS = 60_000; const POLL_INTERVAL_MS = 2000; +const CLOUDWATCH_QUERY_LOOKBACK_MS = Number( + process.env.CLOUDWATCH_QUERY_LOOKBACK_MS ?? 5000, +); type LogEntry = { msg: string; correlationId?: string; callbackType?: string; clientId?: string; + apiKey?: string; signature?: string; payload?: string; + path?: string; }; export type SignedCallback = { payload: CallbackItem; - headers: { "x-hmac-sha256-signature": string }; + path: string; + headers: { + "x-api-key": string; + "x-hmac-sha256-signature": string; + }; }; async function querySignedCallbacksFromWebhookLogGroup( @@ -29,14 +38,12 @@ async function querySignedCallbacksFromWebhookLogGroup( logGroupName: string, messageId: string, callbackType: CallbackItem["type"], - startTime: number, ): Promise { const filterPattern = `{ $.msg = "Callback received" && $.messageId = "${messageId}" && $.callbackType = "${callbackType}" }`; const response = await client.send( new FilterLogEventsCommand({ logGroupName, - startTime, filterPattern, }), ); @@ -51,7 +58,11 @@ async function querySignedCallbacksFromWebhookLogGroup( if (entry.signature !== undefined && entry.payload) { callbacks.push({ payload: JSON.parse(entry.payload) as CallbackItem, - headers: { "x-hmac-sha256-signature": entry.signature }, + path: entry.path ?? "", + headers: { + "x-api-key": entry.apiKey ?? "", + "x-hmac-sha256-signature": entry.signature, + }, }); } } catch { @@ -94,24 +105,37 @@ export async function awaitSignedCallbacksFromWebhookLogGroup( logGroupName: string, messageId: string, callbackType: CallbackItem["type"], - startTime: number, + path: string, ): Promise { logger.debug( - `Waiting for callback in webhook CloudWatch log group (messageId=${messageId}, logGroup=${logGroupName})`, + `Waiting for callback in webhook CloudWatch log group (messageId=${messageId}, path=${path}, logGroup=${logGroupName})`, ); - return pollUntilFound( + const callbacks = await pollUntilFound( () => querySignedCallbacksFromWebhookLogGroup( client, logGroupName, messageId, callbackType, - startTime, ), CALLBACK_WAIT_TIMEOUT_MS, - `Timed out waiting for callback in webhook CloudWatch log group (messageId=${messageId}, callbackType=${callbackType}, timeoutMs=${CALLBACK_WAIT_TIMEOUT_MS})`, + `Timed out waiting for callback in webhook CloudWatch log group (messageId=${messageId}, callbackType=${callbackType}, path=${path}, timeoutMs=${CALLBACK_WAIT_TIMEOUT_MS})`, ); + + if (callbacks.length !== 1) { + throw new Error( + `Expected exactly 1 callback for messageId="${messageId}" callbackType="${callbackType}", but found ${callbacks.length}`, + ); + } + + if (callbacks[0].path !== path) { + throw new Error( + `Expected callback path "${path}" for messageId="${messageId}", but got "${callbacks[0].path}"`, + ); + } + + return callbacks; } type EmfEntry = Record; @@ -139,13 +163,14 @@ async function queryEmfMetricsFromLogGroup( metricNames: string[], startTime: number, ): Promise> { + const queryStartTime = Math.max(0, startTime - CLOUDWATCH_QUERY_LOOKBACK_MS); const conditions = metricNames.map((name) => `$.${name} > 0`).join(" || "); const filterPattern = `{ ${conditions} }`; const response = await client.send( new FilterLogEventsCommand({ logGroupName, - startTime, + startTime: queryStartTime, filterPattern, }), ); @@ -165,8 +190,11 @@ export async function awaitAllEmfMetricsInLogGroup( metricNames: string[], startTime: number, ): Promise { + const queryStartTime = Math.max(0, startTime - CLOUDWATCH_QUERY_LOOKBACK_MS); + const queryStartTimeIso = new Date(queryStartTime).toISOString(); + const startTimeIso = new Date(startTime).toISOString(); logger.debug( - `Waiting for EMF metrics in CloudWatch log group (metrics=${metricNames.join(",")}, logGroup=${logGroupName})`, + `Waiting for EMF metrics in CloudWatch log group (metrics=${metricNames.join(",")}, logGroup=${logGroupName}, startTimeIso=${startTimeIso}, queryStartTimeIso=${queryStartTimeIso}, lookbackMs=${CLOUDWATCH_QUERY_LOOKBACK_MS})`, ); await waitUntil( diff --git a/tests/integration/helpers/event-factories.ts b/tests/integration/helpers/event-factories.ts index 3e25ad3f..015bbced 100644 --- a/tests/integration/helpers/event-factories.ts +++ b/tests/integration/helpers/event-factories.ts @@ -5,6 +5,8 @@ import type { } from "@nhs-notify-client-callbacks/models"; import { EventTypes } from "@nhs-notify-client-callbacks/models"; +import { getMockItClientConfig } from "./mock-client-config"; + type MessageEventOverrides = { event?: Partial>; data?: Partial; @@ -23,7 +25,7 @@ export function createMessageStatusPublishEvent( overrides?.data?.messageReference ?? `ref-${crypto.randomUUID()}`; const baseData: MessageStatusData = { - clientId: "mock-client", + clientId: getMockItClientConfig().clientId, messageId, messageReference, messageStatus: "DELIVERED", @@ -78,7 +80,7 @@ export function createChannelStatusPublishEvent( overrides?.data?.messageReference ?? `ref-${crypto.randomUUID()}`; const baseData: ChannelStatusData = { - clientId: "mock-client", + clientId: getMockItClientConfig().clientId, messageId, messageReference, channel: "NHSAPP", diff --git a/tests/integration/helpers/index.ts b/tests/integration/helpers/index.ts index 7f021566..7c5c70b8 100644 --- a/tests/integration/helpers/index.ts +++ b/tests/integration/helpers/index.ts @@ -3,6 +3,7 @@ export * from "./clients"; export * from "./sqs"; export * from "./cloudwatch"; export { default as sendEventToDlqAndRedrive } from "./redrive"; +export * from "./mock-client-config"; export * from "./status-events"; export * from "./event-factories"; export * from "./signature"; diff --git a/tests/integration/helpers/mock-client-config.ts b/tests/integration/helpers/mock-client-config.ts new file mode 100644 index 00000000..6e484f00 --- /dev/null +++ b/tests/integration/helpers/mock-client-config.ts @@ -0,0 +1,49 @@ +import { readFileSync } from "node:fs"; +import path from "node:path"; +import type seedConfigJson from "../fixtures/mock-client-subscription-1.json"; + +type ClientFixtureShape = typeof seedConfigJson; + +export type MockItClientConfig = ClientFixtureShape & { + apiKeyVar: string; + applicationIdVar: string; +}; + +export const CLIENT_FIXTURES = { + client1: { + fixture: "mock-client-subscription-1.json", + apiKeyVar: "MOCK_CLIENT_API_KEY", + applicationIdVar: "MOCK_CLIENT_APPLICATION_ID", + }, + client2: { + fixture: "mock-client-subscription-2.json", + apiKeyVar: "MOCK_CLIENT_2_API_KEY", + applicationIdVar: "MOCK_CLIENT_2_APPLICATION_ID", + }, +} as const; + +export type ClientFixtureKey = keyof typeof CLIENT_FIXTURES; + +export function getClientConfig(key: ClientFixtureKey): MockItClientConfig { + // eslint-disable-next-line security/detect-object-injection -- key is constrained to ClientFixtureKey, a keyof the hardcoded as-const CLIENT_FIXTURES object + const { apiKeyVar, applicationIdVar, fixture } = CLIENT_FIXTURES[key]; + const resolved = path.resolve(__dirname, "..", "fixtures", fixture); + // eslint-disable-next-line security/detect-non-literal-fs-filename -- path is constructed from __dirname and a basename sourced from the hardcoded as-const CLIENT_FIXTURES registry + const data = JSON.parse(readFileSync(resolved, "utf8")) as ClientFixtureShape; + return { ...data, apiKeyVar, applicationIdVar }; +} + +export function getMockItClientConfig(): MockItClientConfig { + return getClientConfig("client1"); +} + +export function getMockItClient2Config(): MockItClientConfig { + return getClientConfig("client2"); +} + +export function buildMockWebhookTargetPath( + key: ClientFixtureKey = "client1", +): string { + const config = getClientConfig(key); + return `/${config.targets[0].targetId}`; +} diff --git a/tests/integration/helpers/signature.ts b/tests/integration/helpers/signature.ts index 2ed4d191..47780ea4 100644 --- a/tests/integration/helpers/signature.ts +++ b/tests/integration/helpers/signature.ts @@ -1,13 +1,43 @@ import { createHmac } from "node:crypto"; import type { CallbackItem } from "@nhs-notify-client-callbacks/models"; +import type { SignedCallback } from "./cloudwatch"; -const MOCK_HMAC_SECRET = "mock-application-id.some-api-key"; +function resolveEnvVar(name: string): string { + const result = process.env[name]; + if (result) { + return result; + } + throw new Error(`Missing ${name} for integration signature verification`); +} + +function resolveSigningSecret( + apiKeyVar: string, + applicationIdVar: string, +): string { + return `${resolveEnvVar(applicationIdVar)}.${resolveEnvVar(apiKeyVar)}`; +} -export function computeExpectedSignature(payload: CallbackItem): string { - // eslint-disable-next-line sonarjs/hardcoded-secret-signatures - return createHmac("sha256", MOCK_HMAC_SECRET) +export function computeExpectedSignature( + payload: CallbackItem, + apiKeyVar = "MOCK_CLIENT_API_KEY", + applicationIdVar = "MOCK_CLIENT_APPLICATION_ID", +): string { + const signingSecret = resolveSigningSecret(apiKeyVar, applicationIdVar); + return createHmac("sha256", signingSecret) .update(JSON.stringify({ data: [payload] })) .digest("hex"); } +export function assertCallbackHeaders( + callback: SignedCallback, + apiKeyVar = "MOCK_CLIENT_API_KEY", + applicationIdVar = "MOCK_CLIENT_APPLICATION_ID", +): void { + expect(callback.headers["x-api-key"]).toBeDefined(); + expect(callback.headers["x-api-key"]).toBe(resolveEnvVar(apiKeyVar)); + expect(callback.headers["x-hmac-sha256-signature"]).toBe( + computeExpectedSignature(callback.payload, apiKeyVar, applicationIdVar), + ); +} + export default computeExpectedSignature; diff --git a/tests/integration/helpers/sqs.ts b/tests/integration/helpers/sqs.ts index 49f70e78..47264139 100644 --- a/tests/integration/helpers/sqs.ts +++ b/tests/integration/helpers/sqs.ts @@ -12,6 +12,7 @@ import { logger } from "@nhs-notify-client-callbacks/logger"; import { waitUntil } from "async-wait-until"; import type { DeploymentDetails } from "./deployment"; +import { getMockItClientConfig } from "./mock-client-config"; const QUEUE_WAIT_TIMEOUT_MS = 60_000; const POLL_INTERVAL_MS = 500; @@ -62,7 +63,8 @@ export function buildInboundEventDlqQueueUrl( export function buildMockClientDlqQueueUrl( deploymentDetails: DeploymentDetails, ): string { - return buildQueueUrl(deploymentDetails, "mock-client-dlq"); + const { targets } = getMockItClientConfig(); + return buildQueueUrl(deploymentDetails, `${targets[0].targetId}-dlq`); } export async function sendSqsEvent( diff --git a/tests/integration/helpers/status-events.ts b/tests/integration/helpers/status-events.ts index 3d114ceb..1bb763f3 100644 --- a/tests/integration/helpers/status-events.ts +++ b/tests/integration/helpers/status-events.ts @@ -21,9 +21,8 @@ async function processStatusEvent< webhookLogGroupName: string, event: StatusPublishEvent, callbackType: SignedCallback["payload"]["type"], + webhookPath: string, ): Promise { - const startTime = Date.now(); - const sendMessageResponse = await sendSqsEvent( sqsClient, callbackEventQueueUrl, @@ -41,7 +40,7 @@ async function processStatusEvent< webhookLogGroupName, event.data.messageId, callbackType, - startTime, + webhookPath, ); } @@ -51,6 +50,7 @@ export async function processMessageStatusEvent( callbackEventQueueUrl: string, webhookLogGroupName: string, messageStatusEvent: StatusPublishEvent, + webhookPath: string, ): Promise { return processStatusEvent( sqsClient, @@ -59,6 +59,7 @@ export async function processMessageStatusEvent( webhookLogGroupName, messageStatusEvent, "MessageStatus", + webhookPath, ); } @@ -68,6 +69,7 @@ export async function processChannelStatusEvent( callbackEventQueueUrl: string, webhookLogGroupName: string, channelStatusEvent: StatusPublishEvent, + webhookPath: string, ): Promise { return processStatusEvent( sqsClient, @@ -76,5 +78,6 @@ export async function processChannelStatusEvent( webhookLogGroupName, channelStatusEvent, "ChannelStatus", + webhookPath, ); } diff --git a/tests/integration/event-bus-to-webhook.test.ts b/tests/integration/inbound-sqs-to-webhook.test.ts similarity index 94% rename from tests/integration/event-bus-to-webhook.test.ts rename to tests/integration/inbound-sqs-to-webhook.test.ts index 0af0076c..0df2e400 100644 --- a/tests/integration/event-bus-to-webhook.test.ts +++ b/tests/integration/inbound-sqs-to-webhook.test.ts @@ -5,13 +5,14 @@ import { type StatusPublishEvent, } from "@nhs-notify-client-callbacks/models"; import { + assertCallbackHeaders, awaitQueueMessage, awaitQueueMessageByMessageId, buildInboundEventDlqQueueUrl, buildInboundEventQueueUrl, buildLambdaLogGroupName, buildMockClientDlqQueueUrl, - computeExpectedSignature, + buildMockWebhookTargetPath, createChannelStatusPublishEvent, createCloudWatchLogsClient, createMessageStatusPublishEvent, @@ -32,6 +33,7 @@ describe("SQS to Webhook Integration", () => { let clientDlqQueueUrl: string; let inboundEventDlqQueueUrl: string; let webhookLogGroupName: string; + let webhookTargetPath: string; beforeAll(async () => { const deploymentDetails = getDeploymentDetails(); @@ -45,6 +47,7 @@ describe("SQS to Webhook Integration", () => { deploymentDetails, "mock-webhook", ); + webhookTargetPath = buildMockWebhookTargetPath(); await purgeQueues(sqsClient, [ inboundEventDlqQueueUrl, @@ -75,6 +78,7 @@ describe("SQS to Webhook Integration", () => { callbackEventQueueUrl, webhookLogGroupName, messageStatusEvent, + webhookTargetPath, ); expect(callbacks).toHaveLength(1); @@ -87,9 +91,7 @@ describe("SQS to Webhook Integration", () => { }), }); - expect(callbacks[0].headers["x-hmac-sha256-signature"]).toBe( - computeExpectedSignature(callbacks[0].payload), - ); + assertCallbackHeaders(callbacks[0]); }, 120_000); }); @@ -104,6 +106,7 @@ describe("SQS to Webhook Integration", () => { callbackEventQueueUrl, webhookLogGroupName, channelStatusEvent, + webhookTargetPath, ); expect(callbacks).toHaveLength(1); @@ -118,14 +121,12 @@ describe("SQS to Webhook Integration", () => { }), }); - expect(callbacks[0].headers["x-hmac-sha256-signature"]).toBe( - computeExpectedSignature(callbacks[0].payload), - ); + assertCallbackHeaders(callbacks[0]); }, 120_000); }); describe("Client Webhook DLQ", () => { - it("should route a non-retriable (4xx) webhook response to the per-client DLQ", async () => { + it("should route a non-retriable (4xx) webhook response to the per-target DLQ", async () => { const event: StatusPublishEvent = createMessageStatusPublishEvent({ data: { diff --git a/tests/integration/jest.config.ts b/tests/integration/jest.config.ts index fd9a3fa4..c4c673ed 100644 --- a/tests/integration/jest.config.ts +++ b/tests/integration/jest.config.ts @@ -3,8 +3,6 @@ import { nodeJestConfig } from "../../jest.config.base"; export default { ...nodeJestConfig, modulePaths: [""], - globalSetup: "/jest.global-setup.ts", - globalTeardown: "/jest.global-teardown.ts", coveragePathIgnorePatterns: [ ...(nodeJestConfig.coveragePathIgnorePatterns ?? []), "/helpers/", diff --git a/tests/integration/jest.global-setup.ts b/tests/integration/jest.global-setup.ts deleted file mode 100644 index 3f4d58bd..00000000 --- a/tests/integration/jest.global-setup.ts +++ /dev/null @@ -1,60 +0,0 @@ -import { PutObjectCommand } from "@aws-sdk/client-s3"; -import { - buildSubscriptionConfigBucketName, - createS3Client, - getDeploymentDetails, -} from "./helpers"; - -const mockClientSubscriptionKey = "client_subscriptions/mock-client.json"; - -const mockClientSubscriptionBody = JSON.stringify({ - clientId: "mock-client", - subscriptions: [ - { - subscriptionId: "mock-client-message", - subscriptionType: "MessageStatus", - messageStatuses: ["DELIVERED"], - targetIds: ["445527ff-277b-43a4-a4b0-15eedbd71597"], - }, - { - subscriptionId: "mock-client-channel", - subscriptionType: "ChannelStatus", - channelStatuses: ["DELIVERED"], - channelType: "NHSAPP", - supplierStatuses: ["delivered"], - targetIds: ["445527ff-277b-43a4-a4b0-15eedbd71597"], - }, - ], - targets: [ - { - type: "API", - targetId: "445527ff-277b-43a4-a4b0-15eedbd71597", - invocationEndpoint: "https://some-mock-client.endpoint/webhook", - invocationMethod: "POST", - invocationRateLimit: 10, - apiKey: { - headerName: "x-api-key", - headerValue: "some-api-key", - }, - }, - ], -}); - -export default async function globalSetup() { - const deploymentDetails = getDeploymentDetails(); - const bucketName = buildSubscriptionConfigBucketName(deploymentDetails); - const client = createS3Client(deploymentDetails); - - try { - await client.send( - new PutObjectCommand({ - Bucket: bucketName, - Key: mockClientSubscriptionKey, - ContentType: "application/json", - Body: mockClientSubscriptionBody, - }), - ); - } finally { - client.destroy(); - } -} diff --git a/tests/integration/jest.global-teardown.ts b/tests/integration/jest.global-teardown.ts deleted file mode 100644 index 192c1892..00000000 --- a/tests/integration/jest.global-teardown.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { DeleteObjectCommand } from "@aws-sdk/client-s3"; -import { - buildSubscriptionConfigBucketName, - createS3Client, - getDeploymentDetails, -} from "./helpers"; - -const mockClientSubscriptionKey = "client_subscriptions/mock-client.json"; - -export default async function globalTeardown() { - const deploymentDetails = getDeploymentDetails(); - const bucketName = buildSubscriptionConfigBucketName(deploymentDetails); - const client = createS3Client(deploymentDetails); - - try { - await client.send( - new DeleteObjectCommand({ - Bucket: bucketName, - Key: mockClientSubscriptionKey, - }), - ); - } finally { - client.destroy(); - } -} diff --git a/tests/integration/metrics.test.ts b/tests/integration/metrics.test.ts index 19a41013..804556d4 100644 --- a/tests/integration/metrics.test.ts +++ b/tests/integration/metrics.test.ts @@ -11,6 +11,7 @@ import { buildInboundEventQueueUrl, buildLambdaLogGroupName, buildMockClientDlqQueueUrl, + buildMockWebhookTargetPath, createCloudWatchLogsClient, createMessageStatusPublishEvent, createSqsClient, @@ -79,6 +80,7 @@ describe("Metrics", () => { webhookLogGroupName, event.data.messageId, "MessageStatus", + buildMockWebhookTargetPath(), startTime, ); diff --git a/tests/integration/tsconfig.json b/tests/integration/tsconfig.json index a5cc2b81..01538bd3 100644 --- a/tests/integration/tsconfig.json +++ b/tests/integration/tsconfig.json @@ -6,7 +6,8 @@ "helpers": [ "./helpers/index" ] - } + }, + "resolveJsonModule": true }, "extends": "../../tsconfig.base.json", "include": [ diff --git a/tools/client-subscriptions-management/package.json b/tools/client-subscriptions-management/package.json index b675916e..90dfcc87 100644 --- a/tools/client-subscriptions-management/package.json +++ b/tools/client-subscriptions-management/package.json @@ -13,6 +13,8 @@ "targets-list": "tsx ./src/entrypoint/cli/index.ts targets-list", "targets-add": "tsx ./src/entrypoint/cli/index.ts targets-add", "targets-del": "tsx ./src/entrypoint/cli/index.ts targets-del", + "applications-map-add": "tsx ./src/entrypoint/cli/index.ts applications-map-add", + "applications-map-get": "tsx ./src/entrypoint/cli/index.ts applications-map-get", "lint": "eslint .", "lint:fix": "eslint . --fix", "test:unit": "jest", @@ -20,6 +22,7 @@ }, "dependencies": { "@aws-sdk/client-s3": "^3.821.0", + "@aws-sdk/client-ssm": "^3.1011.0", "@aws-sdk/client-sts": "^3.1004.0", "@aws-sdk/credential-providers": "^3.1004.0", "@nhs-notify-client-callbacks/models": "*", diff --git a/tools/client-subscriptions-management/src/__tests__/aws.test.ts b/tools/client-subscriptions-management/src/__tests__/aws.test.ts index f2c3f4e1..f08d0bda 100644 --- a/tools/client-subscriptions-management/src/__tests__/aws.test.ts +++ b/tools/client-subscriptions-management/src/__tests__/aws.test.ts @@ -1,6 +1,8 @@ import { deriveBucketName, + deriveParameterName, resolveBucketName, + resolveParameterName, resolveProfile, resolveRegion, } from "src/aws"; @@ -78,4 +80,38 @@ describe("aws", () => { it("returns undefined when region is not set", () => { expect(resolveRegion(undefined, {} as NodeJS.ProcessEnv)).toBeUndefined(); }); + + it("derives parameter name from environment", () => { + expect(deriveParameterName("dev")).toBe( + "/nhs/dev/callbacks/applications-map", + ); + }); + + it("resolves parameter name from explicit argument", () => { + expect(resolveParameterName({ parameterName: "/custom/path" })).toBe( + "/custom/path", + ); + }); + + it("derives parameter name from environment argument", () => { + expect(resolveParameterName({ environment: "dev" })).toBe( + "/nhs/dev/callbacks/applications-map", + ); + }); + + it("derives parameter name from ENVIRONMENT env var", () => { + expect( + resolveParameterName({ + env: { ENVIRONMENT: "staging" } as NodeJS.ProcessEnv, + }), + ).toBe("/nhs/staging/callbacks/applications-map"); + }); + + it("throws when no parameter name can be resolved", () => { + expect(() => + resolveParameterName({ env: {} as NodeJS.ProcessEnv }), + ).toThrow( + "Environment is required to derive parameter name. Please provide via --environment or ENVIRONMENT env var.", + ); + }); }); diff --git a/tools/client-subscriptions-management/src/__tests__/entrypoint/cli/applications-map-add.test.ts b/tools/client-subscriptions-management/src/__tests__/entrypoint/cli/applications-map-add.test.ts new file mode 100644 index 00000000..99b08ca9 --- /dev/null +++ b/tools/client-subscriptions-management/src/__tests__/entrypoint/cli/applications-map-add.test.ts @@ -0,0 +1,109 @@ +import * as cli from "src/entrypoint/cli/applications-map-add"; +import * as helper from "src/entrypoint/cli/helper"; +import { + captureCliConsoleState, + expectWrappedCliError, + resetCliConsoleState, + restoreCliConsoleState, +} from "src/__tests__/entrypoint/cli/test-utils"; + +const mockAddApplication = jest.fn(); +const mockFormatApplicationsMap = jest.fn(); + +jest.mock("src/entrypoint/cli/helper", () => ({ + ...jest.requireActual("src/entrypoint/cli/helper"), + createSsmApplicationsMapRepository: jest.fn(), +})); + +jest.mock("src/format", () => ({ + ...jest.requireActual("src/format"), + formatApplicationsMap: (...args: unknown[]) => + mockFormatApplicationsMap(...args), +})); + +const mockCreateSsmApplicationsMapRepository = + helper.createSsmApplicationsMapRepository as jest.Mock; + +describe("applications-map-add CLI", () => { + const originalCliConsoleState = captureCliConsoleState(); + + const baseArgs = [ + "node", + "script", + "--client-id", + "client-1", + "--application-id", + "app-1", + "--parameter-name", + "/nhs/dev/callbacks/applications-map", + ]; + + const resultMap = new Map([["client-1", "app-1"]]); + + beforeEach(() => { + mockAddApplication.mockReset(); + mockAddApplication.mockResolvedValue(resultMap); + mockFormatApplicationsMap.mockReset(); + mockFormatApplicationsMap.mockReturnValue("masked-map-output"); + mockCreateSsmApplicationsMapRepository.mockReset(); + mockCreateSsmApplicationsMapRepository.mockReturnValue({ + addApplication: mockAddApplication, + }); + resetCliConsoleState(); + }); + + afterAll(() => { + restoreCliConsoleState(originalCliConsoleState); + }); + + it("adds application and logs output", async () => { + await cli.main(baseArgs); + + expect(mockCreateSsmApplicationsMapRepository).toHaveBeenCalledWith( + expect.objectContaining({ + "client-id": "client-1", + "application-id": "app-1", + "parameter-name": "/nhs/dev/callbacks/applications-map", + }), + ); + expect(mockAddApplication).toHaveBeenCalledWith("client-1", "app-1", false); + expect(console.log).toHaveBeenCalledWith( + "Applications map updated for client 'client-1'.", + ); + expect(mockFormatApplicationsMap).toHaveBeenCalledWith(resultMap); + expect(console.log).toHaveBeenCalledWith("masked-map-output"); + }); + + it("does not log application-id", async () => { + await cli.main(baseArgs); + + const logMessages = (console.log as jest.Mock).mock.calls.flat(); + expect(logMessages).not.toContain("app-1"); + }); + + it("does not log dry-run message when dry-run is false", async () => { + await cli.main(baseArgs); + + expect(console.log).not.toHaveBeenCalledWith( + "Dry run — no changes written to SSM.", + ); + }); + + it("passes dry-run flag to repository and logs dry-run message", async () => { + await cli.main([...baseArgs, "--dry-run"]); + + expect(mockAddApplication).toHaveBeenCalledWith("client-1", "app-1", true); + expect(console.log).toHaveBeenCalledWith( + "Dry run — no changes written to SSM.", + ); + }); + + it("handles errors in wrapped CLI", async () => { + expect.hasAssertions(); + mockCreateSsmApplicationsMapRepository.mockReturnValue({ + addApplication: jest.fn().mockRejectedValue(new Error("Boom")), + }); + + await expectWrappedCliError(cli.main, baseArgs); + }); +}); diff --git a/tools/client-subscriptions-management/src/__tests__/entrypoint/cli/applications-map-get.test.ts b/tools/client-subscriptions-management/src/__tests__/entrypoint/cli/applications-map-get.test.ts new file mode 100644 index 00000000..8d44efeb --- /dev/null +++ b/tools/client-subscriptions-management/src/__tests__/entrypoint/cli/applications-map-get.test.ts @@ -0,0 +1,84 @@ +import * as cli from "src/entrypoint/cli/applications-map-get"; +import * as helper from "src/entrypoint/cli/helper"; +import { + captureCliConsoleState, + expectWrappedCliError, + resetCliConsoleState, + restoreCliConsoleState, +} from "src/__tests__/entrypoint/cli/test-utils"; + +const mockGetApplication = jest.fn(); + +jest.mock("src/entrypoint/cli/helper", () => ({ + ...jest.requireActual("src/entrypoint/cli/helper"), + createSsmApplicationsMapRepository: jest.fn(), +})); + +const mockCreateSsmApplicationsMapRepository = + helper.createSsmApplicationsMapRepository as jest.Mock; + +describe("applications-map-get CLI", () => { + const originalCliConsoleState = captureCliConsoleState(); + + const baseArgs = [ + "node", + "script", + "--client-id", + "client-1", + "--parameter-name", + "/nhs/dev/callbacks/applications-map", + ]; + + beforeEach(() => { + mockGetApplication.mockReset(); + mockCreateSsmApplicationsMapRepository.mockReset(); + mockCreateSsmApplicationsMapRepository.mockReturnValue({ + getApplication: mockGetApplication, + }); + resetCliConsoleState(); + }); + + afterAll(() => { + restoreCliConsoleState(originalCliConsoleState); + }); + + it("prints the application ID when mapping exists", async () => { + mockGetApplication.mockResolvedValue("app-1"); + + await cli.main(baseArgs); + + expect(mockCreateSsmApplicationsMapRepository).toHaveBeenCalledWith( + expect.objectContaining({ + "client-id": "client-1", + "parameter-name": "/nhs/dev/callbacks/applications-map", + }), + ); + expect(mockGetApplication).toHaveBeenCalledWith("client-1"); + expect(console.log).toHaveBeenCalledWith("app-1"); + }); + + it("does not log the application-id in other messages", async () => { + mockGetApplication.mockResolvedValue("app-1"); + + await cli.main(baseArgs); + + const logMessages = (console.log as jest.Mock).mock.calls.flat(); + expect(logMessages).toEqual(["app-1"]); + }); + + it("throws when no mapping exists for the client", async () => { + mockGetApplication.mockResolvedValue(undefined); + + await expectWrappedCliError( + cli.main, + baseArgs, + "No application mapping exists for client: client-1", + ); + }); + + it("handles repository errors", async () => { + mockGetApplication.mockRejectedValue(new Error("Boom")); + + await expectWrappedCliError(cli.main, baseArgs); + }); +}); diff --git a/tools/client-subscriptions-management/src/__tests__/format.test.ts b/tools/client-subscriptions-management/src/__tests__/format.test.ts index ef0b80cf..a8c83570 100644 --- a/tools/client-subscriptions-management/src/__tests__/format.test.ts +++ b/tools/client-subscriptions-management/src/__tests__/format.test.ts @@ -1,4 +1,5 @@ import { + formatApplicationsMap, formatClientConfig, formatSubscriptionsTable, formatTargetsTable, @@ -73,4 +74,23 @@ describe("format", () => { it("normalizes client name", () => { expect(normalizeClientName("My Client Name")).toBe("my-client-name"); }); + + it("formats empty applications map", () => { + expect(formatApplicationsMap(new Map())).toBe("Applications map: (empty)"); + }); + + it("masks application IDs in applications map output", () => { + const result = formatApplicationsMap( + new Map([ + ["client-a", "app-12345"], + ["client-b", "a"], + ]), + ); + + expect(result).toContain("client-a"); + expect(result).toContain("client-b"); + expect(result).toContain("*********"); + expect(result).toContain("*"); + expect(result).not.toContain("app-12345"); + }); }); diff --git a/tools/client-subscriptions-management/src/__tests__/repository/ssm-applications-map.test.ts b/tools/client-subscriptions-management/src/__tests__/repository/ssm-applications-map.test.ts new file mode 100644 index 00000000..afb94e41 --- /dev/null +++ b/tools/client-subscriptions-management/src/__tests__/repository/ssm-applications-map.test.ts @@ -0,0 +1,157 @@ +import { + GetParameterCommand, + PutParameterCommand, + type SSMClient, +} from "@aws-sdk/client-ssm"; +import SsmApplicationsMapRepository from "src/repository/ssm-applications-map"; + +const createRepository = (send: jest.Mock = jest.fn()) => { + const client = { send } as unknown as SSMClient; + return { + repository: new SsmApplicationsMapRepository(client, "/test/param"), + send, + }; +}; + +describe("SsmApplicationsMapRepository", () => { + describe("getApplication", () => { + it("returns the application ID for an existing client", async () => { + const { repository, send } = createRepository(); + send.mockResolvedValueOnce({ + Parameter: { + Value: JSON.stringify({ "client-1": "app-1", "client-2": "app-2" }), + }, + }); + + const result = await repository.getApplication("client-1"); + + expect(send).toHaveBeenCalledWith(expect.any(GetParameterCommand)); + expect(result).toBe("app-1"); + }); + + it("returns undefined when the client is not in the map", async () => { + const { repository, send } = createRepository(); + send.mockResolvedValueOnce({ + Parameter: { Value: JSON.stringify({ "other-client": "app-1" }) }, + }); + + const result = await repository.getApplication("client-1"); + + expect(result).toBeUndefined(); + }); + + it("returns undefined when parameter does not exist", async () => { + const { repository, send } = createRepository(); + const error = Object.assign(new Error("not found"), { + name: "ParameterNotFound", + }); + send.mockRejectedValueOnce(error); + + const result = await repository.getApplication("client-1"); + + expect(result).toBeUndefined(); + }); + + it("returns undefined when parameter has no value", async () => { + const { repository, send } = createRepository(); + send.mockResolvedValueOnce({ Parameter: {} }); + + const result = await repository.getApplication("client-1"); + + expect(result).toBeUndefined(); + }); + + it("rethrows unexpected SSM errors", async () => { + const { repository, send } = createRepository(); + send.mockRejectedValueOnce( + Object.assign(new Error("Network failure"), { name: "NetworkError" }), + ); + + await expect(repository.getApplication("client-1")).rejects.toThrow( + "Network failure", + ); + }); + }); + + describe("addApplication", () => { + it("reads existing map, merges new entry, and writes back", async () => { + const { repository, send } = createRepository(); + send + .mockResolvedValueOnce({ + Parameter: { + Value: JSON.stringify({ "existing-client": "existing-app" }), + }, + }) + .mockResolvedValueOnce({}); + + const result = await repository.addApplication("client-1", "app-1"); + + expect(send).toHaveBeenNthCalledWith(1, expect.any(GetParameterCommand)); + expect(send).toHaveBeenNthCalledWith(2, expect.any(PutParameterCommand)); + expect(result).toEqual( + new Map([ + ["existing-client", "existing-app"], + ["client-1", "app-1"], + ]), + ); + }); + + it("starts from empty map when parameter does not exist", async () => { + const { repository, send } = createRepository(); + const error = Object.assign(new Error("not found"), { + name: "ParameterNotFound", + }); + send.mockRejectedValueOnce(error).mockResolvedValueOnce({}); + + const result = await repository.addApplication("client-1", "app-1"); + + expect(result).toEqual(new Map([["client-1", "app-1"]])); + expect(send).toHaveBeenCalledTimes(2); + }); + + it("starts from empty map when parameter has no value", async () => { + const { repository, send } = createRepository(); + send.mockResolvedValueOnce({ Parameter: {} }).mockResolvedValueOnce({}); + + const result = await repository.addApplication("client-1", "app-1"); + + expect(result).toEqual(new Map([["client-1", "app-1"]])); + }); + + it("overwrites an existing client entry", async () => { + const { repository, send } = createRepository(); + send + .mockResolvedValueOnce({ + Parameter: { Value: JSON.stringify({ "client-1": "old-app" }) }, + }) + .mockResolvedValueOnce({}); + + const result = await repository.addApplication("client-1", "new-app"); + + expect(result).toEqual(new Map([["client-1", "new-app"]])); + }); + + it("skips the put when dry-run is true", async () => { + const { repository, send } = createRepository(); + send.mockResolvedValueOnce({ + Parameter: { Value: JSON.stringify({}) }, + }); + + const result = await repository.addApplication("client-1", "app-1", true); + + expect(send).toHaveBeenCalledTimes(1); + expect(result).toEqual(new Map([["client-1", "app-1"]])); + }); + + it("rethrows unexpected SSM errors", async () => { + const { repository, send } = createRepository(); + send.mockRejectedValueOnce( + Object.assign(new Error("Network failure"), { name: "NetworkError" }), + ); + + await expect( + repository.addApplication("client-1", "app-1"), + ).rejects.toThrow("Network failure"); + }); + }); +}); diff --git a/tools/client-subscriptions-management/src/aws.ts b/tools/client-subscriptions-management/src/aws.ts index e2272f35..5599b50b 100644 --- a/tools/client-subscriptions-management/src/aws.ts +++ b/tools/client-subscriptions-management/src/aws.ts @@ -1,7 +1,9 @@ import { S3Client } from "@aws-sdk/client-s3"; +import { SSMClient } from "@aws-sdk/client-ssm"; import { GetCallerIdentityCommand, STSClient } from "@aws-sdk/client-sts"; import { fromIni } from "@aws-sdk/credential-providers"; import { ClientSubscriptionRepository } from "src/repository/client-subscriptions"; +import SsmApplicationsMapRepository from "src/repository/ssm-applications-map"; import { S3Repository } from "src/repository/s3"; export const resolveProfile = ( @@ -87,3 +89,49 @@ export const createRepository = (options: { ); return new ClientSubscriptionRepository(s3Repository); }; + +export const createSsmClient = ( + region?: string, + profile?: string, + env: NodeJS.ProcessEnv = process.env, +): SSMClient => { + const endpoint = env.AWS_ENDPOINT_URL; + const credentials = profile ? fromIni({ profile }) : undefined; + return new SSMClient({ region, endpoint, credentials }); +}; + +export const deriveParameterName = (environment: string): string => + `/nhs/${environment}/callbacks/applications-map`; + +export const resolveParameterName = (args: { + parameterName?: string; + environment?: string; + env?: NodeJS.ProcessEnv; +}): string => { + const { env = process.env, environment, parameterName } = args; + + if (parameterName) { + return parameterName; + } + + const resolvedEnvironment = environment ?? env.ENVIRONMENT; + if (!resolvedEnvironment) { + throw new Error( + "Environment is required to derive parameter name. Please provide via --environment or ENVIRONMENT env var.", + ); + } + + return deriveParameterName(resolvedEnvironment); +}; + +export const createSsmApplicationsMapRepository = (options: { + parameterName: string; + region?: string; + profile?: string; +}): SsmApplicationsMapRepository => + new SsmApplicationsMapRepository( + createSsmClient(options.region, options.profile), + options.parameterName, + ); + +export { default as SsmApplicationsMapRepository } from "src/repository/ssm-applications-map"; diff --git a/tools/client-subscriptions-management/src/entrypoint/cli/applications-map-add.ts b/tools/client-subscriptions-management/src/entrypoint/cli/applications-map-add.ts new file mode 100644 index 00000000..a98e574f --- /dev/null +++ b/tools/client-subscriptions-management/src/entrypoint/cli/applications-map-add.ts @@ -0,0 +1,60 @@ +import type { Argv } from "yargs"; +import { + type CliCommand, + type ClientCliArgs, + type SsmCliArgs, + type WriteCliArgs, + clientIdOption, + commonOptions, + createSsmApplicationsMapRepository, + parameterNameOption, + runCommand, + writeOptions, +} from "src/entrypoint/cli/helper"; +import { formatApplicationsMap } from "src/format"; + +type ApplicationsMapAddArgs = ClientCliArgs & + SsmCliArgs & + WriteCliArgs & { + "application-id": string; + }; + +export const builder = (yargs: Argv) => + yargs.options({ + ...commonOptions, + ...clientIdOption, + ...parameterNameOption, + ...writeOptions, + "application-id": { + type: "string", + demandOption: true, + description: "Application ID to associate with the client", + }, + }); + +export const handler: CliCommand["handler"] = async ( + argv, +) => { + const repository = createSsmApplicationsMapRepository(argv); + const result = await repository.addApplication( + argv["client-id"], + argv["application-id"], + argv["dry-run"], + ); + console.log(`Applications map updated for client '${argv["client-id"]}'.`); + if (argv["dry-run"]) { + console.log("Dry run — no changes written to SSM."); + } + console.log(formatApplicationsMap(result)); +}; + +export const command: CliCommand = { + command: "applications-map-add", + describe: "Add or update a client-to-application-ID mapping in SSM", + builder, + handler, +}; + +export async function main(args: string[] = process.argv) { + await runCommand(command, args); +} diff --git a/tools/client-subscriptions-management/src/entrypoint/cli/applications-map-get.ts b/tools/client-subscriptions-management/src/entrypoint/cli/applications-map-get.ts new file mode 100644 index 00000000..5ffe2192 --- /dev/null +++ b/tools/client-subscriptions-management/src/entrypoint/cli/applications-map-get.ts @@ -0,0 +1,46 @@ +import type { Argv } from "yargs"; +import { + type CliCommand, + type ClientCliArgs, + type SsmCliArgs, + clientIdOption, + commonOptions, + createSsmApplicationsMapRepository, + parameterNameOption, + runCommand, +} from "src/entrypoint/cli/helper"; + +type ApplicationsMapGetArgs = ClientCliArgs & SsmCliArgs; + +export const builder = (yargs: Argv) => + yargs.options({ + ...commonOptions, + ...clientIdOption, + ...parameterNameOption, + }); + +export const handler: CliCommand["handler"] = async ( + argv, +) => { + const repository = createSsmApplicationsMapRepository(argv); + const applicationId = await repository.getApplication(argv["client-id"]); + + if (applicationId) { + console.log(applicationId); + } else { + throw new Error( + `No application mapping exists for client: ${argv["client-id"]}`, + ); + } +}; + +export const command: CliCommand = { + command: "applications-map-get", + describe: "Get the application ID mapped to a client", + builder, + handler, +}; + +export async function main(args: string[] = process.argv) { + await runCommand(command, args); +} diff --git a/tools/client-subscriptions-management/src/entrypoint/cli/helper.ts b/tools/client-subscriptions-management/src/entrypoint/cli/helper.ts index e9a94487..14e998dd 100644 --- a/tools/client-subscriptions-management/src/entrypoint/cli/helper.ts +++ b/tools/client-subscriptions-management/src/entrypoint/cli/helper.ts @@ -1,6 +1,8 @@ import { createRepository as createRepositoryFromOptions, + createSsmApplicationsMapRepository as createSsmApplicationsMapRepositoryFromOptions, resolveBucketName, + resolveParameterName as resolveParameterNameFromAws, resolveProfile, resolveRegion, } from "src/aws"; @@ -130,3 +132,30 @@ export const writeOptions = { description: "Validate config without writing to S3", }, }; + +export type SsmCliArgs = CommonCliArgs & { + "parameter-name"?: string; +}; + +export const parameterNameOption = { + "parameter-name": { + type: "string" as const, + demandOption: false as const, + description: + "Explicit SSM parameter name for the applications map (overrides derived name)", + }, +}; + +export const createSsmApplicationsMapRepository = (argv: SsmCliArgs) => { + const region = resolveRegion(argv.region); + const profile = resolveProfile(argv.profile); + const parameterName = resolveParameterNameFromAws({ + parameterName: argv["parameter-name"], + environment: argv.environment, + }); + return createSsmApplicationsMapRepositoryFromOptions({ + parameterName, + region, + profile, + }); +}; diff --git a/tools/client-subscriptions-management/src/entrypoint/cli/index.ts b/tools/client-subscriptions-management/src/entrypoint/cli/index.ts index 88d1a733..d13a11a8 100644 --- a/tools/client-subscriptions-management/src/entrypoint/cli/index.ts +++ b/tools/client-subscriptions-management/src/entrypoint/cli/index.ts @@ -1,3 +1,5 @@ +import { command as applicationsMapAddCommand } from "src/entrypoint/cli/applications-map-add"; +import { command as applicationsMapGetCommand } from "src/entrypoint/cli/applications-map-get"; import { command as clientsGetCommand } from "src/entrypoint/cli/clients-get"; import { command as clientsListCommand } from "src/entrypoint/cli/clients-list"; import { command as clientsPutCommand } from "src/entrypoint/cli/clients-put"; @@ -12,6 +14,8 @@ import { command as targetsDelCommand } from "src/entrypoint/cli/targets-del"; import { command as targetsListCommand } from "src/entrypoint/cli/targets-list"; export const commands: AnyCliCommand[] = [ + applicationsMapAddCommand, + applicationsMapGetCommand, clientsListCommand, clientsGetCommand, clientsPutCommand, diff --git a/tools/client-subscriptions-management/src/format.ts b/tools/client-subscriptions-management/src/format.ts index 1c944c06..0944fd16 100644 --- a/tools/client-subscriptions-management/src/format.ts +++ b/tools/client-subscriptions-management/src/format.ts @@ -74,3 +74,16 @@ export const formatClientConfig = ( export const normalizeClientName = (name: string): string => name.replaceAll(/\s+/g, "-").toLowerCase(); + +const maskValue = (value: string): string => "*".repeat(value.length || 8); + +export const formatApplicationsMap = (map: Map): string => + map.size === 0 + ? "Applications map: (empty)" + : table([ + ["Client ID", "Application ID"], + ...[...map.entries()].map(([clientId, applicationId]) => [ + clientId, + maskValue(applicationId), + ]), + ]); diff --git a/tools/client-subscriptions-management/src/repository/ssm-applications-map.ts b/tools/client-subscriptions-management/src/repository/ssm-applications-map.ts new file mode 100644 index 00000000..13553d85 --- /dev/null +++ b/tools/client-subscriptions-management/src/repository/ssm-applications-map.ts @@ -0,0 +1,77 @@ +import { + GetParameterCommand, + PutParameterCommand, + type SSMClient, +} from "@aws-sdk/client-ssm"; + +export default class SsmApplicationsMapRepository { + constructor( + private readonly client: SSMClient, + private readonly parameterName: string, + ) {} + + async getApplication(clientId: string): Promise { + try { + const response = await this.client.send( + new GetParameterCommand({ + Name: this.parameterName, + WithDecryption: true, + }), + ); + if (response.Parameter?.Value) { + const map = JSON.parse(response.Parameter.Value) as Record< + string, + string + >; + return map[clientId]; + } + } catch (error) { + if (error instanceof Error && error.name !== "ParameterNotFound") { + throw error; + } + } + return undefined; + } + + async addApplication( + clientId: string, + applicationId: string, + dryRun = false, + ): Promise> { + let current: Record = {}; + + try { + const response = await this.client.send( + new GetParameterCommand({ + Name: this.parameterName, + WithDecryption: true, + }), + ); + if (response.Parameter?.Value) { + current = JSON.parse(response.Parameter.Value) as Record< + string, + string + >; + } + } catch (error) { + if (error instanceof Error && error.name !== "ParameterNotFound") { + throw error; + } + } + + const updated = { ...current, [clientId]: applicationId }; + + if (!dryRun) { + await this.client.send( + new PutParameterCommand({ + Name: this.parameterName, + Value: JSON.stringify(updated), + Type: "SecureString", + Overwrite: true, + }), + ); + } + + return new Map(Object.entries(updated)); + } +}