diff --git a/clients/aws-sdk-connecthealth/tests/integration/__init__.py b/clients/aws-sdk-connecthealth/tests/integration/__init__.py new file mode 100644 index 0000000..10673d0 --- /dev/null +++ b/clients/aws-sdk-connecthealth/tests/integration/__init__.py @@ -0,0 +1,32 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +from smithy_aws_core.identity import EnvironmentCredentialsResolver + +from aws_sdk_connecthealth.client import ConnectHealthClient +from aws_sdk_connecthealth.config import Config, Plugin + +REGION = "us-east-1" +AUDIO_FILE = Path(__file__).parent / "assets" / "test.wav" + + +def create_connecthealth_client(region: str) -> ConnectHealthClient: + return ConnectHealthClient( + config=Config( + endpoint_uri=f"https://health-agent.{region}.api.aws", + region=region, + aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + ) + ) + + +def streaming_endpoint_plugin(region: str) -> Plugin: + """Per-operation plugin that routes to the ``streaming.`` host prefix.""" + streaming_uri = f"https://streaming.health-agent.{region}.api.aws" + + def _plugin(config: Config) -> None: + config.endpoint_uri = streaming_uri + + return _plugin diff --git a/clients/aws-sdk-connecthealth/tests/integration/assets/test.wav b/clients/aws-sdk-connecthealth/tests/integration/assets/test.wav new file mode 100644 index 0000000..5213ec9 Binary files /dev/null and b/clients/aws-sdk-connecthealth/tests/integration/assets/test.wav differ diff --git a/clients/aws-sdk-connecthealth/tests/integration/conftest.py b/clients/aws-sdk-connecthealth/tests/integration/conftest.py new file mode 100644 index 0000000..4beadef --- /dev/null +++ b/clients/aws-sdk-connecthealth/tests/integration/conftest.py @@ -0,0 +1,198 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Pytest fixtures for ConnectHealth integration tests. + +Creates and tears down a ConnectHealth Domain with an ACTIVE Subscription +plus an S3 bucket once per test session. The ``connecthealth_resources`` +fixture provides ``(domain_id, subscription_id, output_s3_uri)``. +""" + +import uuid +from typing import Any + +import boto3 +import pytest +from botocore.waiter import WaiterModel, create_waiter_with_client + +from . import REGION + +# Tags applied to all resources so orphaned resources from interrupted +# test runs can be discovered and cleaned up. +_TAGS = [{"Key": "Purpose", "Value": "IntegTest"}] + +_WAITER_CONFIG = { + "version": 2, + "waiters": { + "DomainActive": { + "operation": "GetDomain", + "delay": 5, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "expected": "ACTIVE", + "argument": "status", + "state": "success", + }, + { + "matcher": "path", + "expected": "DELETING", + "argument": "status", + "state": "failure", + }, + { + "matcher": "path", + "expected": "DELETED", + "argument": "status", + "state": "failure", + }, + ], + }, + "SubscriptionActive": { + "operation": "GetSubscription", + "delay": 5, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "expected": "ACTIVE", + "argument": "subscription.status", + "state": "success", + }, + { + "matcher": "path", + "expected": "DELETED", + "argument": "subscription.status", + "state": "failure", + }, + ], + }, + "SubscriptionInactive": { + "operation": "GetSubscription", + "delay": 5, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "expected": "INACTIVE", + "argument": "subscription.status", + "state": "success", + }, + { + "matcher": "path", + "expected": "DELETED", + "argument": "subscription.status", + "state": "failure", + }, + ], + }, + }, +} +_waiter_model = WaiterModel(_WAITER_CONFIG) + + +def _create_connecthealth_resources( + s3_client: Any, + connecthealth_client: Any, + domain_name: str, + bucket_name: str, +) -> tuple[str, str]: + """Create an S3 bucket, ConnectHealth Domain, and ACTIVE Subscription. + + Args: + s3_client: A boto3 S3 client. + connecthealth_client: A boto3 ConnectHealth client. + domain_name: The name of the Domain to create. + bucket_name: The name of the S3 bucket to create. + + Returns: + Tuple of (domain_id, subscription_id). + """ + s3_client.create_bucket(Bucket=bucket_name) + s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": _TAGS}) + + response = connecthealth_client.create_domain( + name=domain_name, + tags={t["Key"]: t["Value"] for t in _TAGS}, + ) + domain_id = response["domainId"] + create_waiter_with_client( + "DomainActive", _waiter_model, connecthealth_client + ).wait(domainId=domain_id) + + response = connecthealth_client.create_subscription(domainId=domain_id) + subscription_id = response["subscriptionId"] + create_waiter_with_client( + "SubscriptionActive", _waiter_model, connecthealth_client + ).wait(domainId=domain_id, subscriptionId=subscription_id) + + return domain_id, subscription_id + + +def _delete_connecthealth_resources( + s3_client: Any, + connecthealth_client: Any, + domain_id: str | None, + subscription_id: str | None, + bucket_name: str, +) -> None: + """Deactivate the Subscription, then delete the Domain and S3 bucket. + + Args: + s3_client: A boto3 S3 client. + connecthealth_client: A boto3 ConnectHealth client. + domain_id: The Domain ID to delete, or None if creation failed. + subscription_id: The Subscription ID to deactivate, or None if + creation failed. + bucket_name: The name of the S3 bucket to delete. + """ + if domain_id and subscription_id: + connecthealth_client.deactivate_subscription( + domainId=domain_id, subscriptionId=subscription_id + ) + create_waiter_with_client( + "SubscriptionInactive", _waiter_model, connecthealth_client + ).wait(domainId=domain_id, subscriptionId=subscription_id) + + if domain_id: + connecthealth_client.delete_domain(domainId=domain_id) + + # Empty and delete the bucket + try: + paginator = s3_client.get_paginator("list_objects_v2") + for page in paginator.paginate(Bucket=bucket_name): + objects = page.get("Contents") + if not objects: + continue + s3_client.delete_objects( + Bucket=bucket_name, + Delete={"Objects": [{"Key": o["Key"]} for o in objects]}, + ) + s3_client.delete_bucket(Bucket=bucket_name) + except s3_client.exceptions.NoSuchBucket: + pass + + +@pytest.fixture(scope="session") +def connecthealth_resources(): + """Create ConnectHealth resources for the test session and delete them after.""" + unique_suffix = uuid.uuid4().hex[:16] + domain_name = f"integ-test-connecthealth-domain-{unique_suffix}" + bucket_name = f"integ-test-connecthealth-bucket-{unique_suffix}" + + s3_client = boto3.client("s3", region_name=REGION) + connecthealth_client = boto3.client("connecthealth", region_name=REGION) + + domain_id = None + subscription_id = None + try: + domain_id, subscription_id = _create_connecthealth_resources( + s3_client, connecthealth_client, domain_name, bucket_name + ) + output_s3_uri = f"s3://{bucket_name}/clinical-notes/" + yield domain_id, subscription_id, output_s3_uri + finally: + _delete_connecthealth_resources( + s3_client, connecthealth_client, domain_id, subscription_id, bucket_name + ) diff --git a/clients/aws-sdk-connecthealth/tests/integration/test_bidirectional_streaming.py b/clients/aws-sdk-connecthealth/tests/integration/test_bidirectional_streaming.py new file mode 100644 index 0000000..d0e9eda --- /dev/null +++ b/clients/aws-sdk-connecthealth/tests/integration/test_bidirectional_streaming.py @@ -0,0 +1,215 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Test bidirectional event stream handling.""" + +import asyncio +import time +import uuid + +from smithy_core.aio.eventstream import DuplexEventStream + +from aws_sdk_connecthealth.models import ( + ClinicalNoteGenerationSettings, + ClinicalNoteGenerationSettingsResponse, + EncounterContext, + GetMedicalScribeListeningSessionInput, + GetMedicalScribeListeningSessionOutput, + ManagedTemplate, + ManagedNoteTemplate, + MedicalScribeAudioEvent, + MedicalScribeConfigurationEvent, + MedicalScribeInputStream, + MedicalScribeInputStreamAudioEvent, + MedicalScribeInputStreamConfigurationEvent, + MedicalScribeInputStreamSessionControlEvent, + MedicalScribeLanguageCode, + MedicalScribeMediaEncoding, + MedicalScribeOutputStream, + MedicalScribeOutputStreamTranscriptEvent, + MedicalScribePostStreamActionSettings, + MedicalScribePostStreamActionSettingsResponse, + MedicalScribeSessionControlEvent, + MedicalScribeSessionControlEventType, + MedicalScribeStreamStatus, + NoteTemplateSettingsManagedTemplate, + NoteTemplateSettingsResponseManagedTemplate, + StartMedicalScribeListeningSessionInput, + StartMedicalScribeListeningSessionOutput, +) + +from . import AUDIO_FILE, REGION, create_connecthealth_client, streaming_endpoint_plugin + + +SAMPLE_RATE = 16000 +BYTES_PER_SAMPLE = 2 +CHANNEL_NUMS = 1 +CHUNK_SIZE = 1024 * 8 + + +async def _send_events( + stream: DuplexEventStream[ + MedicalScribeInputStream, + MedicalScribeOutputStream, + StartMedicalScribeListeningSessionOutput, + ], + output_s3_uri: str, +) -> None: + """Send configuration, audio chunks, and end-of-session control event.""" + await stream.input_stream.send( + MedicalScribeInputStreamConfigurationEvent( + value=MedicalScribeConfigurationEvent( + post_stream_action_settings=MedicalScribePostStreamActionSettings( + output_s3_uri=output_s3_uri, + clinical_note_generation_settings=ClinicalNoteGenerationSettings( + note_template_settings=NoteTemplateSettingsManagedTemplate( + value=ManagedTemplate( + template_type=ManagedNoteTemplate.HISTORY_AND_PHYSICAL + ) + ) + ), + ), + encounter_context=EncounterContext( + unstructured_context="Integration test encounter for SDK validation." + ), + ) + ) + ) + + start_time = time.time() + elapsed_audio_time = 0.0 + with AUDIO_FILE.open("rb") as f: + while chunk := f.read(CHUNK_SIZE): + await stream.input_stream.send( + MedicalScribeInputStreamAudioEvent( + value=MedicalScribeAudioEvent(audio_chunk=chunk) + ) + ) + elapsed_audio_time += len(chunk) / ( + BYTES_PER_SAMPLE * SAMPLE_RATE * CHANNEL_NUMS + ) + wait_time = start_time + elapsed_audio_time - time.time() + if wait_time > 0: + await asyncio.sleep(wait_time) + + await stream.input_stream.send( + MedicalScribeInputStreamSessionControlEvent( + value=MedicalScribeSessionControlEvent( + type=MedicalScribeSessionControlEventType.END_OF_SESSION + ) + ) + ) + await stream.input_stream.close() + + +async def _receive_events( + stream: DuplexEventStream[ + MedicalScribeInputStream, + MedicalScribeOutputStream, + StartMedicalScribeListeningSessionOutput, + ], + expected_session_id: str, + expected_domain_id: str, + expected_subscription_id: str, +) -> bool: + """Receive and assert per-event-type fields. + + Returns: + True if at least one transcript event with non-empty content was + received. + """ + got_transcript = False + + start_output, output_stream = await stream.await_output() + + assert isinstance(start_output, StartMedicalScribeListeningSessionOutput) + assert start_output.session_id == expected_session_id + assert start_output.domain_id == expected_domain_id + assert start_output.subscription_id == expected_subscription_id + assert start_output.request_id is not None + assert start_output.language_code == MedicalScribeLanguageCode.EN_US + assert start_output.media_encoding == MedicalScribeMediaEncoding.PCM + assert start_output.media_sample_rate_hertz == SAMPLE_RATE + + if output_stream is None: + return got_transcript + + async for event in output_stream: + if isinstance(event, MedicalScribeOutputStreamTranscriptEvent): + segment = event.value.transcript_segment + assert segment is not None + assert segment.segment_id is not None + assert segment.audio_begin_offset is not None + assert segment.audio_end_offset is not None + assert segment.is_partial is not None + assert segment.channel_id is not None + if segment.content: + got_transcript = True + else: + raise RuntimeError( + f"Received unexpected event type in stream: {type(event).__name__}" + ) + + return got_transcript + + +async def test_start_medical_scribe_listening_session(connecthealth_resources) -> None: + """Test bidirectional streaming with audio input and transcript output.""" + domain_id, subscription_id, output_s3_uri = connecthealth_resources + + client = create_connecthealth_client(REGION) + streaming_plugin = streaming_endpoint_plugin(REGION) + session_id = str(uuid.uuid4()) + + stream = await client.start_medical_scribe_listening_session( + input=StartMedicalScribeListeningSessionInput( + session_id=session_id, + domain_id=domain_id, + subscription_id=subscription_id, + language_code=MedicalScribeLanguageCode.EN_US, + media_sample_rate_hertz=SAMPLE_RATE, + media_encoding=MedicalScribeMediaEncoding.PCM, + ), + plugins=[streaming_plugin], + ) + + results = await asyncio.gather( + _send_events(stream, output_s3_uri), + _receive_events(stream, session_id, domain_id, subscription_id), + ) + got_transcript = results[1] + assert got_transcript, ( + "Expected to receive a transcript event with non-empty content" + ) + + response = await client.get_medical_scribe_listening_session( + input=GetMedicalScribeListeningSessionInput( + session_id=session_id, domain_id=domain_id, subscription_id=subscription_id + ), + plugins=[streaming_plugin], + ) + assert isinstance(response, GetMedicalScribeListeningSessionOutput) + details = response.medical_scribe_listening_session_details + assert details is not None + assert details.session_id == session_id + assert details.stream_status == MedicalScribeStreamStatus.COMPLETED + assert details.language_code == MedicalScribeLanguageCode.EN_US + assert details.media_encoding == MedicalScribeMediaEncoding.PCM + assert details.media_sample_rate_hertz == SAMPLE_RATE + assert details.encounter_context_provided is True + assert isinstance( + details.post_stream_action_settings, + MedicalScribePostStreamActionSettingsResponse, + ) + assert details.post_stream_action_settings.output_s3_uri == output_s3_uri + assert isinstance( + details.post_stream_action_settings.clinical_note_generation_settings, + ClinicalNoteGenerationSettingsResponse, + ) + note_template = details.post_stream_action_settings.clinical_note_generation_settings.note_template_settings + assert isinstance(note_template, NoteTemplateSettingsResponseManagedTemplate) + assert note_template.value is not None + assert note_template.value.template_type == ManagedNoteTemplate.HISTORY_AND_PHYSICAL + assert details.post_stream_action_result is not None + assert details.stream_creation_time is not None + assert details.stream_end_time is not None diff --git a/clients/aws-sdk-connecthealth/tests/integration/test_non_streaming.py b/clients/aws-sdk-connecthealth/tests/integration/test_non_streaming.py new file mode 100644 index 0000000..3365dfd --- /dev/null +++ b/clients/aws-sdk-connecthealth/tests/integration/test_non_streaming.py @@ -0,0 +1,88 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Test non-streaming operations for ConnectHealth.""" + +from aws_sdk_connecthealth.models import ( + DomainStatus, + EncryptionType, + GetDomainInput, + GetDomainOutput, + GetSubscriptionInput, + GetSubscriptionOutput, + ListDomainsInput, + ListDomainsOutput, + ListSubscriptionsInput, + ListSubscriptionsOutput, + SubscriptionStatus, +) + +from . import REGION, create_connecthealth_client + + +async def test_list_domains(connecthealth_resources) -> None: + """Test non-streaming ListDomains operation.""" + _ = connecthealth_resources + client = create_connecthealth_client(REGION) + + response = await client.list_domains(input=ListDomainsInput()) + + assert isinstance(response, ListDomainsOutput) + assert response.domains is not None + assert len(response.domains) >= 1 + + +async def test_get_domain(connecthealth_resources) -> None: + """Test non-streaming GetDomain operation.""" + domain_id, _, _ = connecthealth_resources + client = create_connecthealth_client(REGION) + + response = await client.get_domain(input=GetDomainInput(domain_id=domain_id)) + + assert isinstance(response, GetDomainOutput) + assert response.domain_id == domain_id + assert response.status == DomainStatus.ACTIVE + assert response.name is not None + assert response.name.startswith("integ-test-connecthealth-domain-") + assert response.arn is not None + assert response.created_at is not None + assert response.encryption_context is not None + assert response.encryption_context.encryption_type == EncryptionType.AWS_OWNED_KEY + assert response.tags == {"Purpose": "IntegTest"} + + +async def test_list_subscriptions(connecthealth_resources) -> None: + """Test non-streaming ListSubscriptions operation.""" + domain_id, subscription_id, _ = connecthealth_resources + client = create_connecthealth_client(REGION) + + response = await client.list_subscriptions( + input=ListSubscriptionsInput(domain_id=domain_id) + ) + + assert isinstance(response, ListSubscriptionsOutput) + assert response.subscriptions is not None + assert len(response.subscriptions) == 1 + + sub = response.subscriptions[0] + assert sub.subscription_id == subscription_id + assert sub.domain_id == domain_id + + +async def test_get_subscription(connecthealth_resources) -> None: + """Test non-streaming GetSubscription operation.""" + domain_id, subscription_id, _ = connecthealth_resources + client = create_connecthealth_client(REGION) + + response = await client.get_subscription( + input=GetSubscriptionInput(domain_id=domain_id, subscription_id=subscription_id) + ) + + assert isinstance(response, GetSubscriptionOutput) + assert response.subscription is not None + assert response.subscription.subscription_id == subscription_id + assert response.subscription.domain_id == domain_id + assert response.subscription.status == SubscriptionStatus.ACTIVE + assert response.subscription.arn is not None + assert response.subscription.created_at is not None + assert response.subscription.last_updated_at is not None