22
33import os
44from datetime import timedelta
5- from typing import Literal , cast
5+ from typing import Literal , Protocol , cast
66
77import pytest
88import requests
99from dotenv import load_dotenv
1010
1111load_dotenv ()
1212
13+ _RequestMethod = Literal ["GET" , "POST" ]
1314
14- class Client :
15- """A simple HTTP client for testing purposes."""
1615
17- def __init__ (self , lambda_url : str , timeout : timedelta = timedelta (seconds = 1 )):
18- self ._lambda_url = lambda_url
19- self ._timeout = timeout .total_seconds ()
16+ class Client (Protocol ):
17+ """Protocol defining the interface for HTTP clients."""
2018
2119 def send (
22- self , data : str , path : str , request_method : Literal [ "GET" , "POST" ]
20+ self , data : str , path : str , request_method : _RequestMethod
2321 ) -> requests .Response :
2422 """
2523 Send a request to the APIs with some given parameters.
@@ -30,12 +28,10 @@ def send(
3028 Returns:
3129 Response object from the request
3230 """
33- return self ._send (
34- data = data , path = path , include_payload = True , request_method = request_method
35- )
31+ ...
3632
3733 def send_without_payload (
38- self , path : str , request_method : Literal [ "GET" , "POST" ]
34+ self , path : str , request_method : _RequestMethod
3935 ) -> requests .Response :
4036 """
4137 Send a request to the APIs without a payload.
@@ -45,6 +41,28 @@ def send_without_payload(
4541 Returns:
4642 Response object from the request
4743 """
44+ ...
45+
46+
47+ class LocalClient :
48+ """A simple HTTP client for testing purposes."""
49+
50+ def __init__ (self , lambda_url : str , timeout : timedelta = timedelta (seconds = 1 )):
51+ self ._lambda_url = lambda_url
52+ self ._timeout = timeout .total_seconds ()
53+
54+ def send (
55+ self , data : str , path : str , request_method : _RequestMethod
56+ ) -> requests .Response :
57+
58+ return self ._send (
59+ data = data , path = path , include_payload = True , request_method = request_method
60+ )
61+
62+ def send_without_payload (
63+ self , path : str , request_method : _RequestMethod
64+ ) -> requests .Response :
65+
4866 return self ._send (
4967 data = None , path = path , include_payload = False , request_method = request_method
5068 )
@@ -54,43 +72,149 @@ def _send(
5472 data : str | None ,
5573 path : str ,
5674 include_payload : bool ,
57- request_method : Literal [ "GET" , "POST" ] ,
75+ request_method : _RequestMethod ,
5876 ) -> requests .Response :
77+ url = f"{ self ._lambda_url } /{ path } "
5978 match request_method :
6079 case "POST" :
6180 return requests .post (
62- f" { self . _lambda_url } / { path } " ,
81+ url ,
6382 data = data if include_payload else None ,
6483 timeout = self ._timeout ,
6584 )
6685 case "GET" :
6786 return requests .get (
68- f"{ self ._lambda_url } /{ path } " ,
69- timeout = self ._timeout ,
87+ url ,
7088 data = data if include_payload else None ,
89+ timeout = self ._timeout ,
7190 )
7291
7392
74- @pytest .fixture (scope = "module" )
75- def client (base_url : str ) -> Client :
76- """Create a test client for the application."""
77- return Client (lambda_url = base_url )
93+ class RemoteClient :
94+ """HTTP client for remote testing."""
95+
96+ def __init__ (
97+ self ,
98+ api_url : str ,
99+ auth_headers : dict [str , str ],
100+ timeout : timedelta = timedelta (seconds = 5 ),
101+ ):
102+ self ._api_url = api_url
103+ self ._default_headers = auth_headers | {"Content-Type" : "application/fhir+json" }
104+ self ._timeout = timeout .total_seconds ()
105+
106+ def send (
107+ self ,
108+ data : str ,
109+ path : str ,
110+ request_method : _RequestMethod ,
111+ headers : dict [str , str ] | None = None ,
112+ ) -> requests .Response :
113+
114+ return self ._send (
115+ data = data ,
116+ path = path ,
117+ include_payload = True ,
118+ request_method = request_method ,
119+ headers = headers ,
120+ )
121+
122+ def send_without_payload (
123+ self ,
124+ path : str ,
125+ request_method : _RequestMethod ,
126+ headers : dict [str , str ] | None = None ,
127+ ) -> requests .Response :
78128
129+ return self ._send (
130+ data = None ,
131+ path = path ,
132+ include_payload = False ,
133+ request_method = request_method ,
134+ headers = headers ,
135+ )
79136
80- @pytest .fixture (scope = "module" )
137+ def _send (
138+ self ,
139+ data : str | None ,
140+ path : str ,
141+ include_payload : bool ,
142+ request_method : _RequestMethod ,
143+ headers : dict [str , str ] | None = None ,
144+ ) -> requests .Response :
145+ url = f"{ self ._api_url } /{ path } "
146+ merged_headers = self ._default_headers | (headers or {})
147+ match request_method :
148+ case "POST" :
149+ return requests .post (
150+ url ,
151+ data = data if include_payload else None ,
152+ headers = merged_headers ,
153+ timeout = self ._timeout ,
154+ )
155+ case "GET" :
156+ return requests .get (
157+ url ,
158+ data = data if include_payload else None ,
159+ headers = merged_headers ,
160+ timeout = self ._timeout ,
161+ )
162+
163+
164+ @pytest .fixture
81165def base_url () -> str :
82166 """Retrieves the base URL of the currently deployed application."""
83167 return _fetch_env_variable ("BASE_URL" , str )
84168
85169
86- @pytest .fixture ( scope = "module" )
170+ @pytest .fixture
87171def hostname () -> str :
88172 """Retrieves the hostname of the currently deployed application."""
89173 return _fetch_env_variable ("HOST" , str )
90174
91175
176+ @pytest .fixture
177+ def client (request : pytest .FixtureRequest , base_url : str ) -> Client :
178+ env = request .config .getoption ("--env" )
179+
180+ if env == "local" :
181+ return LocalClient (lambda_url = base_url )
182+ elif env == "remote" :
183+ auth_headers = request .getfixturevalue ("nhsd_apim_auth_headers" )
184+ proxy_url = request .getfixturevalue ("nhsd_apim_proxy_url" )
185+ return RemoteClient (
186+ api_url = proxy_url ,
187+ auth_headers = auth_headers ,
188+ )
189+ else :
190+ raise ValueError (f"Unknown env: { env } " )
191+
192+
92193def _fetch_env_variable [T ](name : str , _ : type [T ]) -> T :
93194 value = os .getenv (name )
94195 if not value :
95196 raise ValueError (f"{ name } environment variable is not set." )
96197 return cast ("T" , value )
198+
199+
200+ def pytest_addoption (parser : pytest .Parser ) -> None :
201+ parser .addoption (
202+ "--env" ,
203+ action = "store" ,
204+ default = "local" ,
205+ help = "Environment to run tests against" ,
206+ )
207+
208+
209+ def pytest_collection_modifyitems (
210+ config : pytest .Config , items : list [pytest .Item ]
211+ ) -> None :
212+ env = config .getoption ("--env" )
213+
214+ if env == "remote" :
215+ for item in items :
216+ item .add_marker (
217+ pytest .mark .nhsd_apim_authorization (
218+ access = "application" , level = "level3"
219+ )
220+ )
0 commit comments