diff --git a/tls_codec/benches/quic_vec.rs b/tls_codec/benches/quic_vec.rs index 9181bbc07..f6a63bdad 100644 --- a/tls_codec/benches/quic_vec.rs +++ b/tls_codec/benches/quic_vec.rs @@ -57,7 +57,7 @@ fn byte_slice(c: &mut Criterion) { c.bench_function("TLS Serialize VL Byte Slice", |b| { b.iter_batched_ref( || (vec![77u8; N], Vec::with_capacity(8 + N)), - |(long_vec, buf)| VLByteSlice(long_vec).tls_serialize(buf).unwrap(), + |(long_vec, buf)| Serialize::tls_serialize(&VLByteSlice(long_vec), buf).unwrap(), BatchSize::SmallInput, ) }); diff --git a/tls_codec/src/lib.rs b/tls_codec/src/lib.rs index cdd6cc155..224f4bc40 100644 --- a/tls_codec/src/lib.rs +++ b/tls_codec/src/lib.rs @@ -38,6 +38,7 @@ use std::io::{Read, Write}; mod arrays; mod primitives; mod quic_vec; +mod string; mod tls_vec; mod varint; diff --git a/tls_codec/src/quic_vec.rs b/tls_codec/src/quic_vec.rs index d477dd9fb..1f046c9c0 100644 --- a/tls_codec/src/quic_vec.rs +++ b/tls_codec/src/quic_vec.rs @@ -539,6 +539,29 @@ impl Size for VLByteSlice<'_> { } } +impl SerializeBytes for ContentLength { + fn tls_serialize(&self) -> Result, Error> { + SerializeBytes::tls_serialize(&self.0) + } +} + +impl SerializeBytes for VLByteSlice<'_> { + fn tls_serialize(&self) -> Result, Error> { + // Get the byte length of the content, make sure it's not too + // large and write it out. + let content_length = self.0.len(); + + let mut len_bytes = + SerializeBytes::tls_serialize(&ContentLength::from_usize(content_length)?)?; + + let mut out = alloc::vec::Vec::with_capacity(content_length + len_bytes.len()); + out.append(&mut len_bytes); + out.extend(self.0); + + Ok(out) + } +} + #[cfg(feature = "std")] pub mod rw { use super::*; @@ -554,7 +577,7 @@ pub mod rw { impl Serialize for ContentLength { #[inline(always)] fn tls_serialize(&self, writer: &mut W) -> Result { - self.0.tls_serialize(writer) + Serialize::tls_serialize(&self.0, writer) } } @@ -598,7 +621,7 @@ pub mod rw { writer: &mut W, content_length: usize, ) -> Result { - ContentLength::from_usize(content_length)?.tls_serialize(writer) + Serialize::tls_serialize(&ContentLength::from_usize(content_length)?, writer) } impl Serialize for Vec { @@ -654,7 +677,8 @@ mod rw_bytes { // large and write it out. let content_length = bytes.len(); - let len_len = ContentLength::from_usize(content_length)?.tls_serialize(writer)?; + let len_len = + Serialize::tls_serialize(&ContentLength::from_usize(content_length)?, writer)?; // Now serialize the elements writer.write_all(bytes)?; diff --git a/tls_codec/src/string.rs b/tls_codec/src/string.rs new file mode 100644 index 000000000..447b9ef22 --- /dev/null +++ b/tls_codec/src/string.rs @@ -0,0 +1,231 @@ +//! This module implements de/serialization for String by storing the UTF-8 representation in a +//! VLByteVec, i.e. a byte vec with a varint Length. + +use alloc::string::String; + +use crate::{DeserializeBytes, SerializeBytes, Size, VLByteSlice, VLByteVec}; + +impl Size for String { + fn tls_serialized_len(&self) -> usize { + self.as_bytes().tls_serialized_len() + } +} + +impl Size for &str { + fn tls_serialized_len(&self) -> usize { + self.as_bytes().tls_serialized_len() + } +} + +impl SerializeBytes for String { + fn tls_serialize(&self) -> Result, crate::Error> { + SerializeBytes::tls_serialize(&VLByteSlice(self.as_bytes())) + } +} + +impl SerializeBytes for &str { + fn tls_serialize(&self) -> Result, crate::Error> { + SerializeBytes::tls_serialize(&self.as_bytes()) + } +} + +impl DeserializeBytes for String { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), crate::Error> + where + Self: Sized, + { + let (bytes, rest) = VLByteVec::tls_deserialize_bytes(bytes)?; + let text = String::from_utf8(bytes.into()) + .map_err(|err| crate::Error::DecodingError(format!("invalid utf8: {err}")))?; + + Ok((text, rest)) + } +} + +#[cfg(feature = "std")] +mod std_only { + use super::*; + use crate::{Deserialize, Serialize}; + + impl Serialize for String { + fn tls_serialize(&self, writer: &mut W) -> Result { + Serialize::tls_serialize(&VLByteSlice(self.as_bytes()), writer) + } + } + + impl Serialize for &str { + fn tls_serialize(&self, writer: &mut W) -> Result { + Serialize::tls_serialize(&self.as_bytes(), writer) + } + } + + impl Deserialize for String { + fn tls_deserialize(bytes: &mut R) -> Result + where + Self: Sized, + { + let bytes = VLByteVec::tls_deserialize(bytes)?; + String::from_utf8(bytes.into()) + .map_err(|err| crate::Error::DecodingError(format!("invalid utf8: {err}"))) + } + } +} + +#[cfg(all(test, feature = "std"))] +mod tests_with_std { + use crate::{Deserialize, Serialize, Size}; + use alloc::string::String; + + #[test] + fn serialize_multibyte_utf8_string() { + // U+00FC = "ü", encoded as 2 bytes in UTF-8: [0xC3, 0xBC] + let s = String::from("ü"); + let buf = s.tls_serialize_detached().unwrap(); + assert_eq!(buf, [2, 0xC3, 0xBC]); + assert_eq!(s.tls_serialized_len(), 3); + } + + #[test] + fn serialize_empty_string() { + let s = String::new(); + let buf = s.tls_serialize_detached().unwrap(); + assert_eq!(buf, [0]); + assert_eq!(s.tls_serialized_len(), 1); + } + + #[test] + fn serialize_hello_string() { + let s = String::from("hello"); + let buf = s.tls_serialize_detached().unwrap(); + // length prefix (5) + b"hello" + assert_eq!(buf, [5, b'h', b'e', b'l', b'l', b'o']); + assert_eq!(s.tls_serialized_len(), 6); + } + + #[test] + fn roundtrip_deserialize() { + let original = String::from("roundtrip test"); + let buf = original.tls_serialize_detached().unwrap(); + let deserialized = String::tls_deserialize_exact(&buf).unwrap(); + assert_eq!(original, deserialized); + } + + #[test] + fn roundtrip_deserialize_longstring() { + let original = String::from_utf8(vec![0x30u8; 300]).unwrap(); + let buf = original.tls_serialize_detached().unwrap(); + let deserialized = String::tls_deserialize_exact(&buf).unwrap(); + assert_eq!(original, deserialized); + } + + #[test] + fn roundtrip_deserialize_empty() { + let original = String::new(); + let buf = original.tls_serialize_detached().unwrap(); + let deserialized = String::tls_deserialize_exact(&buf).unwrap(); + assert_eq!(original, deserialized); + } + + #[test] + fn deserialize_invalid_utf8() { + // length prefix 2 + two bytes that are not valid UTF-8 + let buf: &[u8] = &[2, 0xFF, 0xFE]; + let err = String::tls_deserialize_exact(buf).unwrap_err(); + assert!(matches!(err, crate::Error::DecodingError(msg) if msg.contains("invalid utf8"))); + } +} +#[cfg(test)] +mod tests { + use alloc::string::String; + + #[cfg(feature = "std")] + use crate::Serialize; + + use crate::{DeserializeBytes, SerializeBytes, Size}; + + #[test] + fn serialize_empty_str() { + let s = ""; + + #[cfg(feature = "std")] + { + let mut buf = [0u8; 1]; + Serialize::tls_serialize(&s, &mut buf.as_mut_slice()).unwrap(); + assert_eq!(buf, [0]); + assert_eq!(s.tls_serialized_len(), 1); + } + + let buf = SerializeBytes::tls_serialize(&s).unwrap(); + assert_eq!(buf, [0]); + assert_eq!(s.tls_serialized_len(), 1); + } + + #[test] + fn serialize_hello_str() { + let s = "hello"; + #[cfg(feature = "std")] + { + let mut buf = [0u8; 6]; + Serialize::tls_serialize(&s, &mut buf.as_mut_slice()).unwrap(); + // length prefix (5) + b"hello" + assert_eq!(buf, [5, b'h', b'e', b'l', b'l', b'o']); + assert_eq!(s.tls_serialized_len(), 6); + } + + let buf = SerializeBytes::tls_serialize(&s).unwrap(); + // length prefix (5) + b"hello" + assert_eq!(buf, [5, b'h', b'e', b'l', b'l', b'o']); + assert_eq!(s.tls_serialized_len(), 6); + } + + #[test] + fn serialize_multibyte_utf8_str() { + // U+00FC = "ü", encoded as 2 bytes in UTF-8: [0xC3, 0xBC] + let s = "ü"; + #[cfg(feature = "std")] + { + let mut buf = [0u8; 3]; + Serialize::tls_serialize(&s, &mut buf.as_mut_slice()).unwrap(); + assert_eq!(buf, [2, 0xC3, 0xBC]); + assert_eq!(s.tls_serialized_len(), 3); + } + + let buf = SerializeBytes::tls_serialize(&s).unwrap(); + assert_eq!(buf, [2, 0xC3, 0xBC]); + assert_eq!(s.tls_serialized_len(), 3); + } + + #[test] + fn deserialize_bytes_hello() { + let input = [5, b'h', b'e', b'l', b'l', b'o']; + let (s, rest) = String::tls_deserialize_bytes(&input).unwrap(); + assert_eq!(s, "hello"); + assert!(rest.is_empty()); + assert_eq!(s.tls_serialized_len(), 6); + } + + #[test] + fn deserialize_bytes_with_trailing_data() { + // "hi" (length 2) followed by extra byte 0x99 + let input = [2, b'h', b'i', 0x99]; + let (s, rest) = String::tls_deserialize_bytes(&input).unwrap(); + assert_eq!(s, "hi"); + assert_eq!(rest, [0x99]); + } + + #[test] + fn deserialize_bytes_invalid_utf8() { + // length prefix 3 + 3 bytes that form an invalid UTF-8 sequence + let input = [3, 0xED, 0xA0, 0x80]; // surrogates are invalid in UTF-8 + let err = String::tls_deserialize_exact_bytes(&input).unwrap_err(); + assert!(matches!(err, crate::Error::DecodingError(msg) if msg.contains("invalid utf8"))); + } + + #[test] + fn deserialize_bytes_empty_string() { + let input = [0]; + let (s, rest) = String::tls_deserialize_bytes(&input).unwrap(); + assert_eq!(s, ""); + assert!(rest.is_empty()); + } +} diff --git a/tls_codec/src/varint.rs b/tls_codec/src/varint.rs index b2309fc4d..c289c6cff 100644 --- a/tls_codec/src/varint.rs +++ b/tls_codec/src/varint.rs @@ -1,4 +1,4 @@ -use crate::{Deserialize, DeserializeBytes, Error, Serialize, Size}; +use crate::{Deserialize, DeserializeBytes, Error, Serialize, SerializeBytes, Size}; /// Variable-length encoded unsigned integer as defined in [RFC 9000]. /// @@ -168,6 +168,16 @@ impl Serialize for TlsVarInt { } } +impl SerializeBytes for TlsVarInt { + #[inline] + fn tls_serialize(&self) -> Result, Error> { + let mut bytes = alloc::vec![0u8; 8]; + let len = self.write_bytes(&mut bytes)?; + bytes.truncate(len); + Ok(bytes) + } +} + impl Size for TlsVarInt { #[inline] fn tls_serialized_len(&self) -> usize { @@ -237,10 +247,11 @@ mod tests { for (value, len, bytes) in TESTS { let mut buf = Vec::new(); - let written = TlsVarInt::try_from(value) - .expect("value too large") - .tls_serialize(&mut buf) - .expect("tls serialize failed"); + let written = Serialize::tls_serialize( + &TlsVarInt::try_from(value).expect("value too large"), + &mut buf, + ) + .expect("tls serialize failed"); assert_eq!(written, len, "{value}"); assert_eq!(buf.len(), len, "{value}"); assert_eq!(&buf[..], bytes, "{value}");