From 8d4d3b4abe5d637e9cca48d21a8897238b12732a Mon Sep 17 00:00:00 2001 From: Micha Kalfon Date: Mon, 27 Apr 2026 14:03:03 +0300 Subject: [PATCH] Fix UB in NtUnicodeStrMut::try_from_u16() Avoid transmuting an NtUnicodeStr into a NtUnicodeStrMut which is assigned a *const instead of a *mut for its buffer field. --- src/helpers.rs | 57 ++++++++++++++++++++++++++++++++++++ src/unicode_string/str.rs | 48 +++++------------------------- src/unicode_string/strmut.rs | 50 +++++++++++++++++-------------- 3 files changed, 92 insertions(+), 63 deletions(-) diff --git a/src/helpers.rs b/src/helpers.rs index 956eda4..0b61fde 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -2,6 +2,11 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 use core::cmp::Ordering; +use core::mem; + +use widestring::U16CStr; + +use crate::error::{NtStringError, Result}; /// Generic memory layout unified for `ANSI_STRING`, `OEM_STRING`, `UNICODE_STRING`, /// in their mutable and immutable versions. @@ -16,6 +21,58 @@ pub(crate) struct RawNtString { pub(crate) buffer: T, } +pub(crate) fn check_from_u16(buffer: &[u16]) -> Result { + let elements = buffer.len(); + let length_usize = elements + .checked_mul(mem::size_of::()) + .ok_or(NtStringError::BufferSizeExceedsU16)?; + let length = u16::try_from(length_usize).map_err(|_| NtStringError::BufferSizeExceedsU16)?; + Ok(length) +} + +pub(crate) fn check_from_u16_until_nul(buffer: &[u16]) -> Result<(u16, u16)> { + let length; + let maximum_length; + + match buffer.iter().position(|x| *x == 0) { + Some(nul_pos) => { + // Include the terminating NUL character in `maximum_length` ... + let maximum_elements = nul_pos + .checked_add(1) + .ok_or(NtStringError::BufferSizeExceedsU16)?; + let maximum_length_usize = maximum_elements + .checked_mul(mem::size_of::()) + .ok_or(NtStringError::BufferSizeExceedsU16)?; + maximum_length = u16::try_from(maximum_length_usize) + .map_err(|_| NtStringError::BufferSizeExceedsU16)?; + + // ... but not in `length` + length = maximum_length - mem::size_of::() as u16; + } + None => return Err(NtStringError::NulNotFound), + }; + + Ok((length, maximum_length)) +} + +pub(crate) fn check_from_u16_cstr(u16cstr: &U16CStr) -> Result<(u16, u16)> { + let buffer = u16cstr.as_slice_with_nul(); + + // Include the terminating NUL character in `maximum_length` ... + let maximum_length_in_elements = buffer.len(); + let maximum_length_in_bytes = maximum_length_in_elements + .checked_mul(mem::size_of::()) + .ok_or(NtStringError::BufferSizeExceedsU16)?; + let maximum_length = + u16::try_from(maximum_length_in_bytes).map_err(|_| NtStringError::BufferSizeExceedsU16)?; + + // ... but not in `length` + debug_assert!(maximum_length >= mem::size_of::() as u16); + let length = maximum_length - mem::size_of::() as u16; + + Ok((length, maximum_length)) +} + /// Compare any two `u16` iterators and return an [`Ordering`] value. /// /// Can be used to implement `cmp`/`partial_cmp` and `eq`/`partial_eq`. diff --git a/src/unicode_string/str.rs b/src/unicode_string/str.rs index 0ff25d2..9c02cb6 100644 --- a/src/unicode_string/str.rs +++ b/src/unicode_string/str.rs @@ -10,7 +10,9 @@ use core::{fmt, mem, slice}; use widestring::{U16CStr, U16Str}; use crate::error::{NtStringError, Result}; -use crate::helpers::{cmp_iter, RawNtString}; +use crate::helpers::{ + check_from_u16, check_from_u16_cstr, check_from_u16_until_nul, cmp_iter, RawNtString, +}; use super::iter::{Chars, CharsLossy}; @@ -158,12 +160,7 @@ impl<'a> NtUnicodeStr<'a> { /// /// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul pub fn try_from_u16(buffer: &'a [u16]) -> Result { - let elements = buffer.len(); - let length_usize = elements - .checked_mul(mem::size_of::()) - .ok_or(NtStringError::BufferSizeExceedsU16)?; - let length = - u16::try_from(length_usize).map_err(|_| NtStringError::BufferSizeExceedsU16)?; + let length = check_from_u16(buffer)?; Ok(Self { raw: RawNtString { @@ -191,26 +188,7 @@ impl<'a> NtUnicodeStr<'a> { /// /// [`try_from_u16`]: Self::try_from_u16 pub fn try_from_u16_until_nul(buffer: &'a [u16]) -> Result { - let length; - let maximum_length; - - match buffer.iter().position(|x| *x == 0) { - Some(nul_pos) => { - // Include the terminating NUL character in `maximum_length` ... - let maximum_elements = nul_pos - .checked_add(1) - .ok_or(NtStringError::BufferSizeExceedsU16)?; - let maximum_length_usize = maximum_elements - .checked_mul(mem::size_of::()) - .ok_or(NtStringError::BufferSizeExceedsU16)?; - maximum_length = u16::try_from(maximum_length_usize) - .map_err(|_| NtStringError::BufferSizeExceedsU16)?; - - // ... but not in `length` - length = maximum_length - mem::size_of::() as u16; - } - None => return Err(NtStringError::NulNotFound), - }; + let (length, maximum_length) = check_from_u16_until_nul(buffer)?; Ok(Self { raw: RawNtString { @@ -314,25 +292,13 @@ impl<'a> TryFrom<&'a U16CStr> for NtUnicodeStr<'a> { /// The internal buffer will be NUL-terminated. /// See the [module-level documentation](super) for the implications of that. fn try_from(value: &'a U16CStr) -> Result { - let buffer = value.as_slice_with_nul(); - - // Include the terminating NUL character in `maximum_length` ... - let maximum_length_in_elements = buffer.len(); - let maximum_length_in_bytes = maximum_length_in_elements - .checked_mul(mem::size_of::()) - .ok_or(NtStringError::BufferSizeExceedsU16)?; - let maximum_length = u16::try_from(maximum_length_in_bytes) - .map_err(|_| NtStringError::BufferSizeExceedsU16)?; - - // ... but not in `length` - debug_assert!(maximum_length >= mem::size_of::() as u16); - let length = maximum_length - mem::size_of::() as u16; + let (length, maximum_length) = check_from_u16_cstr(value)?; Ok(Self { raw: RawNtString { length, maximum_length, - buffer: buffer.as_ptr(), + buffer: value.as_ptr(), }, _lifetime: PhantomData, }) diff --git a/src/unicode_string/strmut.rs b/src/unicode_string/strmut.rs index b54c9b3..f2510d9 100644 --- a/src/unicode_string/strmut.rs +++ b/src/unicode_string/strmut.rs @@ -9,7 +9,7 @@ use core::{fmt, mem, slice}; use widestring::{U16CStr, U16Str}; use crate::error::Result; -use crate::helpers::RawNtString; +use crate::helpers::{check_from_u16, check_from_u16_cstr, check_from_u16_until_nul, RawNtString}; use crate::NtStringError; use super::{impl_eq, impl_partial_cmp, NtUnicodeStr}; @@ -122,14 +122,16 @@ impl<'a> NtUnicodeStrMut<'a> { /// /// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul pub fn try_from_u16(buffer: &mut [u16]) -> Result { - let unicode_str = NtUnicodeStr::try_from_u16(buffer)?; + let length = check_from_u16(buffer)?; - // SAFETY: `unicode_str` was created from a mutable `buffer` and - // `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout, - // so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`. - let unicode_str_mut = unsafe { mem::transmute(unicode_str) }; - - Ok(unicode_str_mut) + Ok(Self { + raw: RawNtString { + length, + maximum_length: length, + buffer: buffer.as_mut_ptr(), + }, + _lifetime: PhantomData, + }) } /// Creates an [`NtUnicodeStrMut`] from an existing [`u16`] string buffer that contains at least one NUL character. @@ -148,14 +150,16 @@ impl<'a> NtUnicodeStrMut<'a> { /// /// [`try_from_u16`]: Self::try_from_u16 pub fn try_from_u16_until_nul(buffer: &mut [u16]) -> Result { - let unicode_str = NtUnicodeStr::try_from_u16_until_nul(buffer)?; + let (length, maximum_length) = check_from_u16_until_nul(buffer)?; - // SAFETY: `unicode_str` was created from a mutable `buffer` and - // `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout, - // so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`. - let unicode_str_mut = unsafe { mem::transmute(unicode_str) }; - - Ok(unicode_str_mut) + Ok(Self { + raw: RawNtString { + length, + maximum_length, + buffer: buffer.as_mut_ptr(), + }, + _lifetime: PhantomData, + }) } } @@ -192,14 +196,16 @@ impl<'a> TryFrom<&'a mut U16CStr> for NtUnicodeStrMut<'a> { /// The internal buffer will be NUL-terminated. /// See the [module-level documentation](super) for the implications of that. fn try_from(value: &'a mut U16CStr) -> Result { - let unicode_str = NtUnicodeStr::try_from(&*value)?; + let (length, maximum_length) = check_from_u16_cstr(value)?; - // SAFETY: `unicode_str` was created from a mutable `value` and - // `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout, - // so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`. - let unicode_str_mut = unsafe { mem::transmute(unicode_str) }; - - Ok(unicode_str_mut) + Ok(Self { + raw: RawNtString { + length, + maximum_length, + buffer: value.as_mut_ptr(), + }, + _lifetime: PhantomData, + }) } }