@@ -72,7 +72,8 @@ pub use error::AsyncHttpRangeReaderError;
7272/// if response.status() == reqwest::StatusCode::NOT_MODIFIED {
7373/// Ok(None)
7474/// } else {
75- /// let reader = AsyncHttpRangeReader::from_head_response(client, response, HeaderMap::default()).await?;
75+ /// let url = response.url().clone();
76+ /// let reader = AsyncHttpRangeReader::from_head_response(client, response, url, HeaderMap::default()).await?;
7677/// Ok(Some(reader))
7778/// }
7879/// }
@@ -131,6 +132,15 @@ pub enum CheckSupportMethod {
131132 Head ,
132133}
133134
135+ /// Which URL should be used for subsequent range requests?
136+ pub enum RangeRequestUrlSource {
137+ /// Use the initial request URL
138+ Request ,
139+
140+ /// Use the initial response URL
141+ Response ,
142+ }
143+
134144fn error_for_status ( response : reqwest:: Response ) -> reqwest_middleware:: Result < Response > {
135145 response
136146 . error_for_status ( )
@@ -143,6 +153,7 @@ impl AsyncHttpRangeReader {
143153 client : impl Into < reqwest_middleware:: ClientWithMiddleware > ,
144154 url : reqwest:: Url ,
145155 check_method : CheckSupportMethod ,
156+ range_request_url_source : RangeRequestUrlSource ,
146157 extra_headers : HeaderMap ,
147158 ) -> Result < ( Self , HeaderMap ) , AsyncHttpRangeReaderError > {
148159 let client = client. into ( ) ;
@@ -156,15 +167,23 @@ impl AsyncHttpRangeReader {
156167 )
157168 . await ?;
158169 let response_headers = response. headers ( ) . clone ( ) ;
159- let self_ = Self :: from_tail_response ( client, response, extra_headers) . await ?;
170+ let url = match range_request_url_source {
171+ RangeRequestUrlSource :: Request => url,
172+ RangeRequestUrlSource :: Response => response. url ( ) . clone ( ) ,
173+ } ;
174+ let self_ = Self :: from_tail_response ( client, response, url, extra_headers) . await ?;
160175 Ok ( ( self_, response_headers) )
161176 }
162177 CheckSupportMethod :: Head => {
163178 let response =
164179 Self :: initial_head_request ( client. clone ( ) , url. clone ( ) , HeaderMap :: default ( ) )
165180 . await ?;
166181 let response_headers = response. headers ( ) . clone ( ) ;
167- let self_ = Self :: from_head_response ( client, response, extra_headers) . await ?;
182+ let url = match range_request_url_source {
183+ RangeRequestUrlSource :: Request => url,
184+ RangeRequestUrlSource :: Response => response. url ( ) . clone ( ) ,
185+ } ;
186+ let self_ = Self :: from_head_response ( client, response, url, extra_headers) . await ?;
168187 Ok ( ( self_, response_headers) )
169188 }
170189 }
@@ -200,6 +219,7 @@ impl AsyncHttpRangeReader {
200219 pub async fn from_tail_response (
201220 client : impl Into < reqwest_middleware:: ClientWithMiddleware > ,
202221 tail_request_response : Response ,
222+ url : Url ,
203223 extra_headers : HeaderMap ,
204224 ) -> Result < Self , AsyncHttpRangeReaderError > {
205225 let client = client. into ( ) ;
@@ -245,7 +265,7 @@ impl AsyncHttpRangeReader {
245265 let ( state_tx, state_rx) = watch:: channel ( StreamerState :: default ( ) ) ;
246266 tokio:: spawn ( run_streamer (
247267 client,
248- tail_request_response . url ( ) . clone ( ) ,
268+ url,
249269 extra_headers,
250270 Some ( ( tail_request_response, start) ) ,
251271 memory_map,
@@ -300,6 +320,7 @@ impl AsyncHttpRangeReader {
300320 pub async fn from_head_response (
301321 client : impl Into < reqwest_middleware:: ClientWithMiddleware > ,
302322 head_response : Response ,
323+ url : Url ,
303324 extra_headers : HeaderMap ,
304325 ) -> Result < Self , AsyncHttpRangeReaderError > {
305326 let client = client. into ( ) ;
@@ -345,7 +366,7 @@ impl AsyncHttpRangeReader {
345366 let ( state_tx, state_rx) = watch:: channel ( StreamerState :: default ( ) ) ;
346367 tokio:: spawn ( run_streamer (
347368 client,
348- head_response . url ( ) . clone ( ) ,
369+ url,
349370 extra_headers,
350371 None ,
351372 memory_map,
@@ -688,6 +709,7 @@ mod test {
688709 Client :: new ( ) ,
689710 server. url ( ) . join ( "andes-1.8.3-pyhd8ed1ab_0.conda" ) . unwrap ( ) ,
690711 check_method,
712+ RangeRequestUrlSource :: Response ,
691713 HeaderMap :: default ( ) ,
692714 )
693715 . await
@@ -728,7 +750,7 @@ mod test {
728750 ) ;
729751
730752 // Prefetch the data for the metadata.json file
731- let entry = reader. file ( ) . entries ( ) . get ( 0 ) . unwrap ( ) ;
753+ let entry = reader. file ( ) . entries ( ) . first ( ) . unwrap ( ) ;
732754 let offset = entry. header_offset ( ) ;
733755 // Get the size of the entry plus the header + size of the filename. We should also actually
734756 // include bytes for the extra fields but we don't have that information.
@@ -783,6 +805,57 @@ mod test {
783805 Client :: new ( ) ,
784806 server. url ( ) . join ( "andes-1.8.3-pyhd8ed1ab_0.conda" ) . unwrap ( ) ,
785807 check_method,
808+ RangeRequestUrlSource :: Response ,
809+ HeaderMap :: default ( ) ,
810+ )
811+ . await
812+ . expect ( "bla" ) ;
813+
814+ // Also open a simple file reader
815+ let mut file = tokio:: fs:: File :: open ( path. join ( "andes-1.8.3-pyhd8ed1ab_0.conda" ) )
816+ . await
817+ . unwrap ( ) ;
818+
819+ // Read until the end and make sure that the contents matches
820+ let mut range_read = vec ! [ 0 ; 64 * 1024 ] ;
821+ let mut file_read = vec ! [ 0 ; 64 * 1024 ] ;
822+ loop {
823+ // Read with the async reader
824+ let range_read_bytes = range. read ( & mut range_read) . await . unwrap ( ) ;
825+
826+ // Read directly from the file
827+ let file_read_bytes = file
828+ . read_exact ( & mut file_read[ 0 ..range_read_bytes] )
829+ . await
830+ . unwrap ( ) ;
831+
832+ assert_eq ! ( range_read_bytes, file_read_bytes) ;
833+ assert_eq ! (
834+ range_read[ 0 ..range_read_bytes] ,
835+ file_read[ 0 ..file_read_bytes]
836+ ) ;
837+
838+ if file_read_bytes == 0 && range_read_bytes == 0 {
839+ break ;
840+ }
841+ }
842+ }
843+
844+ #[ rstest]
845+ #[ case( RangeRequestUrlSource :: Request ) ]
846+ #[ case( RangeRequestUrlSource :: Response ) ]
847+ #[ tokio:: test]
848+ async fn async_range_reader_url_source ( #[ case] url_source : RangeRequestUrlSource ) {
849+ // Spawn a static file server
850+ let path = Path :: new ( & std:: env:: var ( "CARGO_MANIFEST_DIR" ) . unwrap ( ) ) . join ( "test-data" ) ;
851+ let server = StaticDirectoryServer :: new ( & path) ;
852+
853+ // Construct an AsyncRangeReader
854+ let ( mut range, _) = AsyncHttpRangeReader :: new (
855+ Client :: new ( ) ,
856+ server. url ( ) . join ( "andes-1.8.3-pyhd8ed1ab_0.conda" ) . unwrap ( ) ,
857+ CheckSupportMethod :: Head ,
858+ url_source,
786859 HeaderMap :: default ( ) ,
787860 )
788861 . await
@@ -825,6 +898,7 @@ mod test {
825898 Client :: new ( ) ,
826899 server. url ( ) . join ( "not-found" ) . unwrap ( ) ,
827900 CheckSupportMethod :: Head ,
901+ RangeRequestUrlSource :: Response ,
828902 HeaderMap :: default ( ) ,
829903 )
830904 . await
0 commit comments