Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
// SPDX-License-Identifier: MIT OR Apache-2.0

use core::cmp::Ordering;
use core::mem;

use widestring::U16CStr;

use crate::error::{NtStringError, Result};

/// Generic memory layout unified for `ANSI_STRING`, `OEM_STRING`, `UNICODE_STRING`,
/// in their mutable and immutable versions.
Expand All @@ -16,6 +21,58 @@ pub(crate) struct RawNtString<T> {
pub(crate) buffer: T,
}

pub(crate) fn check_from_u16(buffer: &[u16]) -> Result<u16> {
let elements = buffer.len();
let length_usize = elements
.checked_mul(mem::size_of::<u16>())
.ok_or(NtStringError::BufferSizeExceedsU16)?;
let length = u16::try_from(length_usize).map_err(|_| NtStringError::BufferSizeExceedsU16)?;
Ok(length)
}

pub(crate) fn check_from_u16_until_nul(buffer: &[u16]) -> Result<(u16, u16)> {
let length;
let maximum_length;

match buffer.iter().position(|x| *x == 0) {
Some(nul_pos) => {
// Include the terminating NUL character in `maximum_length` ...
let maximum_elements = nul_pos
.checked_add(1)
.ok_or(NtStringError::BufferSizeExceedsU16)?;
let maximum_length_usize = maximum_elements
.checked_mul(mem::size_of::<u16>())
.ok_or(NtStringError::BufferSizeExceedsU16)?;
maximum_length = u16::try_from(maximum_length_usize)
.map_err(|_| NtStringError::BufferSizeExceedsU16)?;

// ... but not in `length`
length = maximum_length - mem::size_of::<u16>() as u16;
}
None => return Err(NtStringError::NulNotFound),
};

Ok((length, maximum_length))
}

pub(crate) fn check_from_u16_cstr(u16cstr: &U16CStr) -> Result<(u16, u16)> {
let buffer = u16cstr.as_slice_with_nul();

// Include the terminating NUL character in `maximum_length` ...
let maximum_length_in_elements = buffer.len();
let maximum_length_in_bytes = maximum_length_in_elements
.checked_mul(mem::size_of::<u16>())
.ok_or(NtStringError::BufferSizeExceedsU16)?;
let maximum_length =
u16::try_from(maximum_length_in_bytes).map_err(|_| NtStringError::BufferSizeExceedsU16)?;

// ... but not in `length`
debug_assert!(maximum_length >= mem::size_of::<u16>() as u16);
let length = maximum_length - mem::size_of::<u16>() as u16;

Ok((length, maximum_length))
}

/// Compare any two `u16` iterators and return an [`Ordering`] value.
///
/// Can be used to implement `cmp`/`partial_cmp` and `eq`/`partial_eq`.
Expand Down
48 changes: 7 additions & 41 deletions src/unicode_string/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use core::{fmt, mem, slice};
use widestring::{U16CStr, U16Str};

use crate::error::{NtStringError, Result};
use crate::helpers::{cmp_iter, RawNtString};
use crate::helpers::{
check_from_u16, check_from_u16_cstr, check_from_u16_until_nul, cmp_iter, RawNtString,
};

use super::iter::{Chars, CharsLossy};

Expand Down Expand Up @@ -158,12 +160,7 @@ impl<'a> NtUnicodeStr<'a> {
///
/// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul
pub fn try_from_u16(buffer: &'a [u16]) -> Result<Self> {
let elements = buffer.len();
let length_usize = elements
.checked_mul(mem::size_of::<u16>())
.ok_or(NtStringError::BufferSizeExceedsU16)?;
let length =
u16::try_from(length_usize).map_err(|_| NtStringError::BufferSizeExceedsU16)?;
let length = check_from_u16(buffer)?;

Ok(Self {
raw: RawNtString {
Expand Down Expand Up @@ -191,26 +188,7 @@ impl<'a> NtUnicodeStr<'a> {
///
/// [`try_from_u16`]: Self::try_from_u16
pub fn try_from_u16_until_nul(buffer: &'a [u16]) -> Result<Self> {
let length;
let maximum_length;

match buffer.iter().position(|x| *x == 0) {
Some(nul_pos) => {
// Include the terminating NUL character in `maximum_length` ...
let maximum_elements = nul_pos
.checked_add(1)
.ok_or(NtStringError::BufferSizeExceedsU16)?;
let maximum_length_usize = maximum_elements
.checked_mul(mem::size_of::<u16>())
.ok_or(NtStringError::BufferSizeExceedsU16)?;
maximum_length = u16::try_from(maximum_length_usize)
.map_err(|_| NtStringError::BufferSizeExceedsU16)?;

// ... but not in `length`
length = maximum_length - mem::size_of::<u16>() as u16;
}
None => return Err(NtStringError::NulNotFound),
};
let (length, maximum_length) = check_from_u16_until_nul(buffer)?;

Ok(Self {
raw: RawNtString {
Expand Down Expand Up @@ -314,25 +292,13 @@ impl<'a> TryFrom<&'a U16CStr> for NtUnicodeStr<'a> {
/// The internal buffer will be NUL-terminated.
/// See the [module-level documentation](super) for the implications of that.
fn try_from(value: &'a U16CStr) -> Result<Self> {
let buffer = value.as_slice_with_nul();

// Include the terminating NUL character in `maximum_length` ...
let maximum_length_in_elements = buffer.len();
let maximum_length_in_bytes = maximum_length_in_elements
.checked_mul(mem::size_of::<u16>())
.ok_or(NtStringError::BufferSizeExceedsU16)?;
let maximum_length = u16::try_from(maximum_length_in_bytes)
.map_err(|_| NtStringError::BufferSizeExceedsU16)?;

// ... but not in `length`
debug_assert!(maximum_length >= mem::size_of::<u16>() as u16);
let length = maximum_length - mem::size_of::<u16>() as u16;
let (length, maximum_length) = check_from_u16_cstr(value)?;

Ok(Self {
raw: RawNtString {
length,
maximum_length,
buffer: buffer.as_ptr(),
buffer: value.as_ptr(),
},
_lifetime: PhantomData,
})
Expand Down
50 changes: 28 additions & 22 deletions src/unicode_string/strmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use core::{fmt, mem, slice};
use widestring::{U16CStr, U16Str};

use crate::error::Result;
use crate::helpers::RawNtString;
use crate::helpers::{check_from_u16, check_from_u16_cstr, check_from_u16_until_nul, RawNtString};
use crate::NtStringError;

use super::{impl_eq, impl_partial_cmp, NtUnicodeStr};
Expand Down Expand Up @@ -122,14 +122,16 @@ impl<'a> NtUnicodeStrMut<'a> {
///
/// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul
pub fn try_from_u16(buffer: &mut [u16]) -> Result<Self> {
let unicode_str = NtUnicodeStr::try_from_u16(buffer)?;
let length = check_from_u16(buffer)?;

// SAFETY: `unicode_str` was created from a mutable `buffer` and
// `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout,
// so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`.
let unicode_str_mut = unsafe { mem::transmute(unicode_str) };

Ok(unicode_str_mut)
Ok(Self {
raw: RawNtString {
length,
maximum_length: length,
buffer: buffer.as_mut_ptr(),
},
_lifetime: PhantomData,
})
}

/// Creates an [`NtUnicodeStrMut`] from an existing [`u16`] string buffer that contains at least one NUL character.
Expand All @@ -148,14 +150,16 @@ impl<'a> NtUnicodeStrMut<'a> {
///
/// [`try_from_u16`]: Self::try_from_u16
pub fn try_from_u16_until_nul(buffer: &mut [u16]) -> Result<Self> {
let unicode_str = NtUnicodeStr::try_from_u16_until_nul(buffer)?;
let (length, maximum_length) = check_from_u16_until_nul(buffer)?;

// SAFETY: `unicode_str` was created from a mutable `buffer` and
// `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout,
// so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`.
let unicode_str_mut = unsafe { mem::transmute(unicode_str) };

Ok(unicode_str_mut)
Ok(Self {
raw: RawNtString {
length,
maximum_length,
buffer: buffer.as_mut_ptr(),
},
_lifetime: PhantomData,
})
}
}

Expand Down Expand Up @@ -192,14 +196,16 @@ impl<'a> TryFrom<&'a mut U16CStr> for NtUnicodeStrMut<'a> {
/// The internal buffer will be NUL-terminated.
/// See the [module-level documentation](super) for the implications of that.
fn try_from(value: &'a mut U16CStr) -> Result<Self> {
let unicode_str = NtUnicodeStr::try_from(&*value)?;
let (length, maximum_length) = check_from_u16_cstr(value)?;

// SAFETY: `unicode_str` was created from a mutable `value` and
// `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout,
// so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`.
let unicode_str_mut = unsafe { mem::transmute(unicode_str) };

Ok(unicode_str_mut)
Ok(Self {
raw: RawNtString {
length,
maximum_length,
buffer: value.as_mut_ptr(),
},
_lifetime: PhantomData,
})
}
}

Expand Down