Skip to content

Commit 87099b9

Browse files
committed
Allow the request URL to be used for subsequent responses
1 parent da28771 commit 87099b9

1 file changed

Lines changed: 80 additions & 6 deletions

File tree

src/lib.rs

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ pub use error::AsyncHttpRangeReaderError;
7272
/// if response.status() == reqwest::StatusCode::NOT_MODIFIED {
7373
/// Ok(None)
7474
/// } else {
75-
/// let reader = AsyncHttpRangeReader::from_head_response(client, response, HeaderMap::default()).await?;
75+
/// let url = response.url().clone();
76+
/// let reader = AsyncHttpRangeReader::from_head_response(client, response, url, HeaderMap::default()).await?;
7677
/// Ok(Some(reader))
7778
/// }
7879
/// }
@@ -131,6 +132,15 @@ pub enum CheckSupportMethod {
131132
Head,
132133
}
133134

135+
/// Which URL should be used for subsequent range requests?
136+
pub enum RangeRequestUrlSource {
137+
/// Use the initial request URL
138+
Request,
139+
140+
/// Use the initial response URL
141+
Response,
142+
}
143+
134144
fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
135145
response
136146
.error_for_status()
@@ -143,6 +153,7 @@ impl AsyncHttpRangeReader {
143153
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
144154
url: reqwest::Url,
145155
check_method: CheckSupportMethod,
156+
range_request_url_source: RangeRequestUrlSource,
146157
extra_headers: HeaderMap,
147158
) -> Result<(Self, HeaderMap), AsyncHttpRangeReaderError> {
148159
let client = client.into();
@@ -156,15 +167,23 @@ impl AsyncHttpRangeReader {
156167
)
157168
.await?;
158169
let response_headers = response.headers().clone();
159-
let self_ = Self::from_tail_response(client, response, extra_headers).await?;
170+
let url = match range_request_url_source {
171+
RangeRequestUrlSource::Request => url,
172+
RangeRequestUrlSource::Response => response.url().clone(),
173+
};
174+
let self_ = Self::from_tail_response(client, response, url, extra_headers).await?;
160175
Ok((self_, response_headers))
161176
}
162177
CheckSupportMethod::Head => {
163178
let response =
164179
Self::initial_head_request(client.clone(), url.clone(), HeaderMap::default())
165180
.await?;
166181
let response_headers = response.headers().clone();
167-
let self_ = Self::from_head_response(client, response, extra_headers).await?;
182+
let url = match range_request_url_source {
183+
RangeRequestUrlSource::Request => url,
184+
RangeRequestUrlSource::Response => response.url().clone(),
185+
};
186+
let self_ = Self::from_head_response(client, response, url, extra_headers).await?;
168187
Ok((self_, response_headers))
169188
}
170189
}
@@ -200,6 +219,7 @@ impl AsyncHttpRangeReader {
200219
pub async fn from_tail_response(
201220
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
202221
tail_request_response: Response,
222+
url: Url,
203223
extra_headers: HeaderMap,
204224
) -> Result<Self, AsyncHttpRangeReaderError> {
205225
let client = client.into();
@@ -245,7 +265,7 @@ impl AsyncHttpRangeReader {
245265
let (state_tx, state_rx) = watch::channel(StreamerState::default());
246266
tokio::spawn(run_streamer(
247267
client,
248-
tail_request_response.url().clone(),
268+
url,
249269
extra_headers,
250270
Some((tail_request_response, start)),
251271
memory_map,
@@ -300,6 +320,7 @@ impl AsyncHttpRangeReader {
300320
pub async fn from_head_response(
301321
client: impl Into<reqwest_middleware::ClientWithMiddleware>,
302322
head_response: Response,
323+
url: Url,
303324
extra_headers: HeaderMap,
304325
) -> Result<Self, AsyncHttpRangeReaderError> {
305326
let client = client.into();
@@ -345,7 +366,7 @@ impl AsyncHttpRangeReader {
345366
let (state_tx, state_rx) = watch::channel(StreamerState::default());
346367
tokio::spawn(run_streamer(
347368
client,
348-
head_response.url().clone(),
369+
url,
349370
extra_headers,
350371
None,
351372
memory_map,
@@ -688,6 +709,7 @@ mod test {
688709
Client::new(),
689710
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
690711
check_method,
712+
RangeRequestUrlSource::Response,
691713
HeaderMap::default(),
692714
)
693715
.await
@@ -728,7 +750,7 @@ mod test {
728750
);
729751

730752
// Prefetch the data for the metadata.json file
731-
let entry = reader.file().entries().get(0).unwrap();
753+
let entry = reader.file().entries().first().unwrap();
732754
let offset = entry.header_offset();
733755
// Get the size of the entry plus the header + size of the filename. We should also actually
734756
// include bytes for the extra fields but we don't have that information.
@@ -783,6 +805,57 @@ mod test {
783805
Client::new(),
784806
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
785807
check_method,
808+
RangeRequestUrlSource::Response,
809+
HeaderMap::default(),
810+
)
811+
.await
812+
.expect("bla");
813+
814+
// Also open a simple file reader
815+
let mut file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda"))
816+
.await
817+
.unwrap();
818+
819+
// Read until the end and make sure that the contents matches
820+
let mut range_read = vec![0; 64 * 1024];
821+
let mut file_read = vec![0; 64 * 1024];
822+
loop {
823+
// Read with the async reader
824+
let range_read_bytes = range.read(&mut range_read).await.unwrap();
825+
826+
// Read directly from the file
827+
let file_read_bytes = file
828+
.read_exact(&mut file_read[0..range_read_bytes])
829+
.await
830+
.unwrap();
831+
832+
assert_eq!(range_read_bytes, file_read_bytes);
833+
assert_eq!(
834+
range_read[0..range_read_bytes],
835+
file_read[0..file_read_bytes]
836+
);
837+
838+
if file_read_bytes == 0 && range_read_bytes == 0 {
839+
break;
840+
}
841+
}
842+
}
843+
844+
#[rstest]
845+
#[case(RangeRequestUrlSource::Request)]
846+
#[case(RangeRequestUrlSource::Response)]
847+
#[tokio::test]
848+
async fn async_range_reader_url_source(#[case] url_source: RangeRequestUrlSource) {
849+
// Spawn a static file server
850+
let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data");
851+
let server = StaticDirectoryServer::new(&path);
852+
853+
// Construct an AsyncRangeReader
854+
let (mut range, _) = AsyncHttpRangeReader::new(
855+
Client::new(),
856+
server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(),
857+
CheckSupportMethod::Head,
858+
url_source,
786859
HeaderMap::default(),
787860
)
788861
.await
@@ -825,6 +898,7 @@ mod test {
825898
Client::new(),
826899
server.url().join("not-found").unwrap(),
827900
CheckSupportMethod::Head,
901+
RangeRequestUrlSource::Response,
828902
HeaderMap::default(),
829903
)
830904
.await

0 commit comments

Comments
 (0)