diff --git a/src/unicode_string/str.rs b/src/unicode_string/str.rs index 0ff25d2..36388fe 100644 --- a/src/unicode_string/str.rs +++ b/src/unicode_string/str.rs @@ -1,4 +1,4 @@ -// Copyright 2023 Colin Finck +// Copyright 2023-2026 Colin Finck // SPDX-License-Identifier: MIT OR Apache-2.0 use core::cmp::Ordering; @@ -158,12 +158,7 @@ impl<'a> NtUnicodeStr<'a> { /// /// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul pub fn try_from_u16(buffer: &'a [u16]) -> Result { - let elements = buffer.len(); - let length_usize = elements - .checked_mul(mem::size_of::()) - .ok_or(NtStringError::BufferSizeExceedsU16)?; - let length = - u16::try_from(length_usize).map_err(|_| NtStringError::BufferSizeExceedsU16)?; + let length = Self::try_length_from_u16(buffer)?; Ok(Self { raw: RawNtString { @@ -191,9 +186,48 @@ impl<'a> NtUnicodeStr<'a> { /// /// [`try_from_u16`]: Self::try_from_u16 pub fn try_from_u16_until_nul(buffer: &'a [u16]) -> Result { - let length; - let maximum_length; + let (length, maximum_length) = Self::try_length_from_u16_until_nul(buffer)?; + + Ok(Self { + raw: RawNtString { + length, + maximum_length, + buffer: buffer.as_ptr(), + }, + _lifetime: PhantomData, + }) + } + + pub(crate) fn try_length_from_u16(buffer: &[u16]) -> Result { + let elements = buffer.len(); + let length_usize = elements + .checked_mul(mem::size_of::()) + .ok_or(NtStringError::BufferSizeExceedsU16)?; + let length = + u16::try_from(length_usize).map_err(|_| NtStringError::BufferSizeExceedsU16)?; + + Ok(length) + } + + pub(crate) fn try_length_from_u16_cstr(u16cstr: &U16CStr) -> Result<(u16, u16)> { + let buffer = u16cstr.as_slice_with_nul(); + + // Include the terminating NUL character in `maximum_length` ... + let maximum_length_in_elements = buffer.len(); + let maximum_length_in_bytes = maximum_length_in_elements + .checked_mul(mem::size_of::()) + .ok_or(NtStringError::BufferSizeExceedsU16)?; + let maximum_length = u16::try_from(maximum_length_in_bytes) + .map_err(|_| NtStringError::BufferSizeExceedsU16)?; + + // ... but not in `length` + debug_assert!(maximum_length >= mem::size_of::() as u16); + let length = maximum_length - mem::size_of::() as u16; + + Ok((length, maximum_length)) + } + pub(crate) fn try_length_from_u16_until_nul(buffer: &[u16]) -> Result<(u16, u16)> { match buffer.iter().position(|x| *x == 0) { Some(nul_pos) => { // Include the terminating NUL character in `maximum_length` ... @@ -203,23 +237,16 @@ impl<'a> NtUnicodeStr<'a> { let maximum_length_usize = maximum_elements .checked_mul(mem::size_of::()) .ok_or(NtStringError::BufferSizeExceedsU16)?; - maximum_length = u16::try_from(maximum_length_usize) + let maximum_length = u16::try_from(maximum_length_usize) .map_err(|_| NtStringError::BufferSizeExceedsU16)?; // ... but not in `length` - length = maximum_length - mem::size_of::() as u16; - } - None => return Err(NtStringError::NulNotFound), - }; + let length = maximum_length - mem::size_of::() as u16; - Ok(Self { - raw: RawNtString { - length, - maximum_length, - buffer: buffer.as_ptr(), - }, - _lifetime: PhantomData, - }) + Ok((length, maximum_length)) + } + None => Err(NtStringError::NulNotFound), + } } pub(crate) fn u16_iter(&'a self) -> Copied> { @@ -314,25 +341,13 @@ impl<'a> TryFrom<&'a U16CStr> for NtUnicodeStr<'a> { /// The internal buffer will be NUL-terminated. /// See the [module-level documentation](super) for the implications of that. fn try_from(value: &'a U16CStr) -> Result { - let buffer = value.as_slice_with_nul(); - - // Include the terminating NUL character in `maximum_length` ... - let maximum_length_in_elements = buffer.len(); - let maximum_length_in_bytes = maximum_length_in_elements - .checked_mul(mem::size_of::()) - .ok_or(NtStringError::BufferSizeExceedsU16)?; - let maximum_length = u16::try_from(maximum_length_in_bytes) - .map_err(|_| NtStringError::BufferSizeExceedsU16)?; - - // ... but not in `length` - debug_assert!(maximum_length >= mem::size_of::() as u16); - let length = maximum_length - mem::size_of::() as u16; + let (length, maximum_length) = Self::try_length_from_u16_cstr(value)?; Ok(Self { raw: RawNtString { length, maximum_length, - buffer: buffer.as_ptr(), + buffer: value.as_ptr(), }, _lifetime: PhantomData, }) diff --git a/src/unicode_string/strmut.rs b/src/unicode_string/strmut.rs index b54c9b3..aae43d5 100644 --- a/src/unicode_string/strmut.rs +++ b/src/unicode_string/strmut.rs @@ -1,4 +1,4 @@ -// Copyright 2023 Colin Finck +// Copyright 2023-2026 Colin Finck // SPDX-License-Identifier: MIT OR Apache-2.0 use core::cmp::Ordering; @@ -122,14 +122,16 @@ impl<'a> NtUnicodeStrMut<'a> { /// /// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul pub fn try_from_u16(buffer: &mut [u16]) -> Result { - let unicode_str = NtUnicodeStr::try_from_u16(buffer)?; + let length = NtUnicodeStr::try_length_from_u16(buffer)?; - // SAFETY: `unicode_str` was created from a mutable `buffer` and - // `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout, - // so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`. - let unicode_str_mut = unsafe { mem::transmute(unicode_str) }; - - Ok(unicode_str_mut) + Ok(Self { + raw: RawNtString { + length, + maximum_length: length, + buffer: buffer.as_mut_ptr(), + }, + _lifetime: PhantomData, + }) } /// Creates an [`NtUnicodeStrMut`] from an existing [`u16`] string buffer that contains at least one NUL character. @@ -148,14 +150,16 @@ impl<'a> NtUnicodeStrMut<'a> { /// /// [`try_from_u16`]: Self::try_from_u16 pub fn try_from_u16_until_nul(buffer: &mut [u16]) -> Result { - let unicode_str = NtUnicodeStr::try_from_u16_until_nul(buffer)?; + let (length, maximum_length) = NtUnicodeStr::try_length_from_u16_until_nul(buffer)?; - // SAFETY: `unicode_str` was created from a mutable `buffer` and - // `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout, - // so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`. - let unicode_str_mut = unsafe { mem::transmute(unicode_str) }; - - Ok(unicode_str_mut) + Ok(Self { + raw: RawNtString { + length, + maximum_length, + buffer: buffer.as_mut_ptr(), + }, + _lifetime: PhantomData, + }) } } @@ -192,14 +196,16 @@ impl<'a> TryFrom<&'a mut U16CStr> for NtUnicodeStrMut<'a> { /// The internal buffer will be NUL-terminated. /// See the [module-level documentation](super) for the implications of that. fn try_from(value: &'a mut U16CStr) -> Result { - let unicode_str = NtUnicodeStr::try_from(&*value)?; + let (length, maximum_length) = NtUnicodeStr::try_length_from_u16_cstr(value)?; - // SAFETY: `unicode_str` was created from a mutable `value` and - // `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout, - // so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`. - let unicode_str_mut = unsafe { mem::transmute(unicode_str) }; - - Ok(unicode_str_mut) + Ok(Self { + raw: RawNtString { + length, + maximum_length, + buffer: value.as_mut_ptr(), + }, + _lifetime: PhantomData, + }) } }