diff --git a/src/error.rs b/src/error.rs index 8b41f9c93d..0d8b78bb9c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -92,7 +92,7 @@ pub(super) enum Parse { Version, #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] VersionH2, - Uri, + Uri(Option), #[cfg(all(feature = "http1", feature = "server"))] UriTooLong, #[cfg(feature = "http1")] @@ -459,7 +459,7 @@ impl Error { Kind::Parse(Parse::Version) => "invalid HTTP version parsed", #[cfg(all(any(feature = "client", feature = "server"), feature = "http1"))] Kind::Parse(Parse::VersionH2) => "invalid HTTP version parsed (found HTTP2 preface)", - Kind::Parse(Parse::Uri) => "invalid URI", + Kind::Parse(Parse::Uri(_)) => "invalid URI", #[cfg(all(feature = "http1", feature = "server"))] Kind::Parse(Parse::UriTooLong) => "URI too long", #[cfg(feature = "http1")] @@ -582,7 +582,10 @@ impl StdError for Error { #[doc(hidden)] impl From for Error { fn from(err: Parse) -> Error { - Error::new(Kind::Parse(err)) + match err { + Parse::Uri(Some(cause)) => Error::new(Kind::Parse(Parse::Uri(None))).with(cause), + other => Error::new(Kind::Parse(other)), + } } } @@ -632,14 +635,14 @@ impl From for Parse { } impl From for Parse { - fn from(_: http::uri::InvalidUri) -> Parse { - Parse::Uri + fn from(err: http::uri::InvalidUri) -> Parse { + Parse::Uri(Some(Box::new(err))) } } impl From for Parse { - fn from(_: http::uri::InvalidUriParts) -> Parse { - Parse::Uri + fn from(err: http::uri::InvalidUriParts) -> Parse { + Parse::Uri(Some(Box::new(err))) } } @@ -692,4 +695,27 @@ mod tests { let svc_err = Error::new_user_service(recvd); assert_eq!(svc_err.h2_reason(), h2::Reason::HTTP_1_1_REQUIRED); } + + #[test] + fn uri_error_preserves_source() { + use std::error::Error as _; + + // Parse an invalid URI through the http crate + let invalid: std::result::Result = "dangling whitespace ".parse(); + let uri_err: http::uri::InvalidUri = invalid.unwrap_err(); + + // Convert through the same path hyper uses: InvalidUri -> Parse -> Error + let parse: Parse = Parse::from(uri_err); + let error: Error = Error::from(parse); + + // The error should have a source + assert!(error.source().is_some(), "URI error should preserve source"); + + // The source should be the original InvalidUri + let source = error.source().unwrap(); + assert!( + source.downcast_ref::().is_some(), + "source should be http::uri::InvalidUri" + ); + } } diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index f92092e5a9..aae7095443 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -199,7 +199,7 @@ impl Http1Transaction for Server { Parse::Method } else { debug_assert!(req.path.is_none()); - Parse::Uri + Parse::Uri(None) } }) } @@ -467,7 +467,7 @@ impl Http1Transaction for Server { let status = match *err.kind() { Kind::Parse(Parse::Method) | Kind::Parse(Parse::Header(_)) - | Kind::Parse(Parse::Uri) + | Kind::Parse(Parse::Uri(_)) | Kind::Parse(Parse::Version) => StatusCode::BAD_REQUEST, Kind::Parse(Parse::TooLarge) => StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE, Kind::Parse(Parse::UriTooLong) => StatusCode::URI_TOO_LONG,