diff --git a/src/bipaddr.rs b/src/bipaddr.rs new file mode 100644 index 0000000..b05322e --- /dev/null +++ b/src/bipaddr.rs @@ -0,0 +1,160 @@ +//! Byte IpAddr which helps with the deserialization. + +use core::{ + borrow::{Borrow, BorrowMut}, + fmt, net, + ops::{Deref, DerefMut}, +}; + +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, +}; + +/// IpAddr serialize to and deserialize from big-endian bytes representation. +/// +/// Bencoded "strings" are not necessarily UTF-8 encoded values so if a field is +/// not guranteed to be a UTF-8 string, then you should use a `ByteString` or +/// another equivalent type. +/// +/// Due to a limitation within `serde` and Rust, a `IpAddr::V4` and `IpAddr::V6` will +/// serialize and deserialize as a list of individual byte elements. Serializing `IpAddr` +/// requires serialize enum which this bencode/library does not support yet. +/// +/// # Examples +/// +/// ```rust +/// use bt_bencode::ByteIpAddr; +/// use core::net; +/// +/// let v4_bytes = [1,2,3,4]; +/// let v6_bytes = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]; +/// let v4 = net::IpAddr::from(v4_bytes); +/// let v6 = net::IpAddr::from(v6_bytes); +/// let bip4 = ByteIpAddr::from(v4); +/// +/// +/// let encoded = bt_bencode::to_vec(&bip4)?; +/// assert_eq!(encoded, b"4:\x01\x02\x03\x04"); +/// +/// let decoded: ByteIpAddr = bt_bencode::from_slice(&encoded)?; +/// assert_eq!(decoded, v4.into()); +/// +/// let bip6 = ByteIpAddr::from(v6); +/// +/// let encoded = bt_bencode::to_vec(&bip6)?; +/// assert_eq!(encoded, b"16:\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10"); +/// +/// let decoded: ByteIpAddr = bt_bencode::from_slice(&encoded)?; +/// assert_eq!(decoded, v6.into()); +/// +/// // test invalid bytes +/// assert!(bt_bencode::from_slice::(b"5:\x01\x02\x03\x04\x05").is_err()); +/// +/// # Ok::<(), bt_bencode::Error>(()) +/// ``` +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ByteIpAddr(net::IpAddr); + +impl AsRef for ByteIpAddr { + fn as_ref(&self) -> &net::IpAddr { + &self.0 + } +} + +impl AsMut for ByteIpAddr { + fn as_mut(&mut self) -> &mut net::IpAddr { + &mut self.0 + } +} + +impl Borrow for ByteIpAddr { + fn borrow(&self) -> &net::IpAddr { + &self.0 + } +} + +impl BorrowMut for ByteIpAddr { + fn borrow_mut(&mut self) -> &mut net::IpAddr { + &mut self.0 + } +} + +impl fmt::Debug for ByteIpAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.0, f) + } +} + +impl Deref for ByteIpAddr { + type Target = net::IpAddr; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ByteIpAddr { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for ByteIpAddr +where + net::IpAddr: From, +{ + fn from(value: T) -> Self { + Self(net::IpAddr::from(value)) + } +} + +impl serde::Serialize for ByteIpAddr { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self.0 { + net::IpAddr::V4(ip) => serializer.serialize_bytes(&ip.octets()), + net::IpAddr::V6(ip) => serializer.serialize_bytes(&ip.octets()), + } + } +} + +struct IpAddrVisitor; + +impl<'de> Visitor<'de> for IpAddrVisitor { + type Value = ByteIpAddr; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("byte string of length 4(ipv4) or 16(ipv6)") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: de::Error, + { + match v.len() { + 4 => Ok(ByteIpAddr(net::IpAddr::V4(net::Ipv4Addr::from([ + v[0], v[1], v[2], v[3], + ])))), + 16 => Ok(ByteIpAddr(net::IpAddr::V6(net::Ipv6Addr::from([ + v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7], v[8], v[9], v[10], v[11], v[12], + v[13], v[14], v[15], + ])))), + other => Err(de::Error::invalid_value( + de::Unexpected::Str(&format!("get byte string {v:02x?} of length {other}")), + &self, + )), + } + } +} + +impl<'de> Deserialize<'de> for ByteIpAddr { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_byte_buf(IpAddrVisitor) + } +} diff --git a/src/de.rs b/src/de.rs index e385439..515cd0e 100644 --- a/src/de.rs +++ b/src/de.rs @@ -263,6 +263,62 @@ impl<'a> Deserializer> { } } +impl<'de, R: Read<'de>> Deserializer { + // #[cfg(feature = "raw_value")] + fn deserialize_raw_value(&mut self, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + self.read.begin_raw_buffering(); + self.ignore_value()?; + self.read.end_raw_buffering(visitor) + } + + fn ignore_value(&mut self) -> Result<()> { + match self.parse_peek()? { + b'0'..=b'9' => { + self.buf.clear(); + self.read.parse_byte_str(&mut self.buf)?; + Ok(()) + } + b'i' => { + self.parse_next()?; + self.parse_integer()?; + Ok(()) + } + b'l' => { + self.parse_next()?; + loop { + if self.parse_peek()? == b'e' { + self.parse_next()?; + break; + } + self.ignore_value()?; + } + Ok(()) + } + b'd' => { + self.parse_next()?; + loop { + if self.parse_peek()? == b'e' { + self.parse_next()?; + break; + } + + self.buf.clear(); + self.read.parse_byte_str(&mut self.buf)?; + self.ignore_value()?; + } + Ok(()) + } + _ => Err(Error::new( + ErrorKind::ExpectedSomeValue, + self.read.byte_offset(), + )), + } + } +} + impl<'de, R: Read<'de>> de::Deserializer<'de> for &mut Deserializer { type Error = Error; @@ -393,10 +449,16 @@ impl<'de, R: Read<'de>> de::Deserializer<'de> for &mut Deserializer { } #[inline] - fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result where V: de::Visitor<'de>, { + // #[cfg(feature = "raw_value")] + if name == crate::raw::TOKEN { + return self.deserialize_raw_value(visitor); + } + + let _ = name; visitor.visit_newtype_struct(self) } @@ -481,10 +543,16 @@ where } #[inline] - fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result where V: de::Visitor<'de>, { + // #[cfg(feature = "raw_value")] + if name == crate::raw::TOKEN { + return self.de.deserialize_raw_value(visitor); + } + + let _ = name; visitor.visit_newtype_struct(self) } diff --git a/src/lib.rs b/src/lib.rs index 878f0c7..63817e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -105,6 +105,7 @@ extern crate alloc; #[macro_use] extern crate serde; +mod bipaddr; mod bstring; mod de; mod error; @@ -112,9 +113,14 @@ mod error; pub mod read; pub mod write; +// #[cfg(feature = "raw_value")] +mod raw; + mod ser; pub mod value; +#[doc(inline)] +pub use bipaddr::ByteIpAddr; #[doc(inline)] pub use bstring::ByteString; #[doc(inline)] @@ -134,3 +140,6 @@ pub use ser::{to_vec, Serializer}; #[doc(inline)] #[cfg(feature = "std")] pub use de::from_reader; + +// #[cfg(feature = "raw_value")] +pub use raw::RawValue; diff --git a/src/read.rs b/src/read.rs index f656f92..080f091 100644 --- a/src/read.rs +++ b/src/read.rs @@ -9,6 +9,9 @@ use alloc::vec::Vec; #[cfg(feature = "std")] use std::{io, vec::Vec}; +// #[cfg(feature = "raw_value")] +use {crate::raw::OwnedRawDeserializer, serde::de::Visitor}; + /// A reference to borrowed data. /// /// The variant determines if the slice comes from a long lived source (e.g. an @@ -125,6 +128,21 @@ pub trait Read<'a> { /// - malformatted input /// - end of file fn parse_raw_dict<'b>(&'b mut self, buf: &'b mut Vec) -> Result>; + + /// Switch raw buffering mode on. + /// + /// This is used when deserializing `RawValue`. + // #[cfg(feature = "raw_value")] + #[doc(hidden)] + fn begin_raw_buffering(&mut self); + + /// Switch raw buffering mode off and provides the raw buffered data to the + /// given visitor. + // #[cfg(feature = "raw_value")] + #[doc(hidden)] + fn end_raw_buffering(&mut self, visitor: V) -> Result + where + V: Visitor<'a>; } /// A wrapper to implement this crate's [Read] trait for [`std::io::Read`] trait implementations. @@ -138,6 +156,9 @@ where iter: io::Bytes, peeked_byte: Option, byte_offset: usize, + + // #[cfg(feature = "raw_value")] + raw_buffer: Option>, } #[cfg(feature = "std")] @@ -157,6 +178,9 @@ where iter: reader.bytes(), peeked_byte: None, byte_offset: 0, + + // #[cfg(feature = "raw_value")] + raw_buffer: None, } } } @@ -170,11 +194,23 @@ where fn next(&mut self) -> Option> { match self.peeked_byte.take() { Some(b) => { + // #[cfg(feature = "raw_value")] + { + if let Some(buf) = &mut self.raw_buffer { + buf.push(b); + } + } self.byte_offset += 1; Some(Ok(b)) } None => match self.iter.next() { Some(Ok(b)) => { + // #[cfg(feature = "raw_value")] + { + if let Some(buf) = &mut self.raw_buffer { + buf.push(b); + } + } self.byte_offset += 1; Some(Ok(b)) } @@ -402,6 +438,22 @@ where } } } + + // #[cfg(feature = "raw_value")] + fn begin_raw_buffering(&mut self) { + self.raw_buffer = Some(Vec::new()); + } + + // #[cfg(feature = "raw_value")] + fn end_raw_buffering(&mut self, visitor: V) -> Result + where + V: Visitor<'a>, + { + let raw = self.raw_buffer.take().unwrap(); + visitor.visit_map(OwnedRawDeserializer { + raw_value: Some(raw), + }) + } } /// A wrapper to implement this crate's [Read] trait for byte slices. @@ -410,6 +462,9 @@ where pub struct SliceRead<'a> { slice: &'a [u8], byte_offset: usize, + + // #[cfg(feature = "raw_value")] + raw_buffer: Option>, } impl<'a> SliceRead<'a> { @@ -419,6 +474,9 @@ impl<'a> SliceRead<'a> { SliceRead { slice, byte_offset: 0, + + // #[cfg(feature = "raw_value")] + raw_buffer: None, } } } @@ -428,6 +486,12 @@ impl<'a> Read<'a> for SliceRead<'a> { fn next(&mut self) -> Option> { if self.byte_offset < self.slice.len() { let b = self.slice[self.byte_offset]; + // #[cfg(feature = "raw_value")] + { + if let Some(buf) = &mut self.raw_buffer { + buf.push(b); + } + } self.byte_offset += 1; Some(Ok(b)) } else { @@ -484,6 +548,13 @@ impl<'a> Read<'a> for SliceRead<'a> { )); } + // #[cfg(feature = "raw_value")] + { + if let Some(buf) = &mut self.raw_buffer { + buf.extend_from_slice(&self.slice[start_idx..self.byte_offset]); + } + } + Ok(Ref::Source(&self.slice[start_idx..self.byte_offset])) } @@ -648,4 +719,20 @@ impl<'a> Read<'a> for SliceRead<'a> { } } } + + // #[cfg(feature = "raw_value")] + fn begin_raw_buffering(&mut self) { + self.raw_buffer = Some(Vec::new()); + } + + // #[cfg(feature = "raw_value")] + fn end_raw_buffering(&mut self, visitor: V) -> Result + where + V: Visitor<'a>, + { + let raw = self.raw_buffer.take().unwrap(); + visitor.visit_map(OwnedRawDeserializer { + raw_value: Some(raw), + }) + } } diff --git a/src/ser.rs b/src/ser.rs index 12ad3fc..4824b10 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -90,8 +90,8 @@ where type SerializeTuple = Self; type SerializeTupleStruct = Self; type SerializeTupleVariant = ser::Impossible<(), Error>; - type SerializeMap = SerializeMap<'a, W>; - type SerializeStruct = SerializeMap<'a, W>; + type SerializeMap = Compound<'a, W>; + type SerializeStruct = Compound<'a, W>; type SerializeStructVariant = ser::Impossible<(), Error>; #[inline] @@ -213,10 +213,17 @@ where } #[inline] - fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result<()> + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result<()> where T: ?Sized + Serialize, { + // #[cfg(feature = "raw_value")] + if name == crate::raw::TOKEN { + let raw = value.serialize(RawValueBytesSerializer)?; + self.writer.write_all(&raw)?; + return Ok(()); + } + value.serialize(self) } @@ -268,11 +275,16 @@ where #[inline] fn serialize_map(self, _len: Option) -> Result { self.writer.write_all(b"d")?; - Ok(SerializeMap::new(self)) + Ok(Compound::new_map(self)) } #[inline] - fn serialize_struct(self, _name: &'static str, len: usize) -> Result { + fn serialize_struct(self, name: &'static str, len: usize) -> Result { + // #[cfg(feature = "raw_value")] + if name == crate::raw::TOKEN { + return Ok(Compound::RawValue { ser: self }); + } + let _ = name; self.serialize_map(Some(len)) } @@ -292,6 +304,361 @@ where } } +struct RawValueBytesSerializer; + +impl ser::Serializer for RawValueBytesSerializer { + type Ok = Vec; + type Error = Error; + + type SerializeSeq = RawValueBytesSeq; + type SerializeTuple = ser::Impossible, Error>; + type SerializeTupleStruct = ser::Impossible, Error>; + type SerializeTupleVariant = ser::Impossible, Error>; + type SerializeMap = ser::Impossible, Error>; + type SerializeStruct = ser::Impossible, Error>; + type SerializeStructVariant = ser::Impossible, Error>; + + fn serialize_bytes(self, value: &[u8]) -> Result> { + Ok(value.to_vec()) + } + + fn serialize_seq(self, _len: Option) -> Result { + Ok(RawValueBytesSeq { bytes: Vec::new() }) + } + + fn serialize_bool(self, _value: bool) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_i8(self, _value: i8) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_i16(self, _value: i16) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_i32(self, _value: i32) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_i64(self, _value: i64) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_u8(self, _value: u8) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_u16(self, _value: u16) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_u32(self, _value: u32) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_u64(self, _value: u64) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_f32(self, _value: f32) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_f64(self, _value: f64) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_char(self, _value: char) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_str(self, _value: &str) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_none(self) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_some(self, _value: &T) -> Result> + where + T: ?Sized + Serialize, + { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_unit(self) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result> { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result> + where + T: ?Sized + Serialize, + { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result> + where + T: ?Sized + Serialize, + { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn is_human_readable(&self) -> bool { + false + } +} + +struct RawValueBytesSeq { + bytes: Vec, +} + +impl ser::SerializeSeq for RawValueBytesSeq { + type Ok = Vec; + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.bytes.push(value.serialize(RawValueByteSerializer)?); + Ok(()) + } + + fn end(self) -> Result> { + Ok(self.bytes) + } +} + +struct RawValueByteSerializer; + +impl ser::Serializer for RawValueByteSerializer { + type Ok = u8; + type Error = Error; + + type SerializeSeq = ser::Impossible; + type SerializeTuple = ser::Impossible; + type SerializeTupleStruct = ser::Impossible; + type SerializeTupleVariant = ser::Impossible; + type SerializeMap = ser::Impossible; + type SerializeStruct = ser::Impossible; + type SerializeStructVariant = ser::Impossible; + + fn serialize_u8(self, value: u8) -> Result { + Ok(value) + } + + fn serialize_bool(self, _value: bool) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_i8(self, _value: i8) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_i16(self, _value: i16) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_i32(self, _value: i32) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_i64(self, _value: i64) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_u16(self, _value: u16) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_u32(self, _value: u32) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_u64(self, _value: u64) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_f32(self, _value: f32) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_f64(self, _value: f64) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_char(self, _value: char) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_str(self, _value: &str) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_bytes(self, _value: &[u8]) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_none(self) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_some(self, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_unit(self) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + + fn is_human_readable(&self) -> bool { + false + } +} + impl ser::SerializeSeq for &mut Serializer where W: Write, @@ -361,41 +728,33 @@ where /// A serializer for writing map data. #[doc(hidden)] #[derive(Debug)] -pub struct SerializeMap<'a, W> { - ser: &'a mut Serializer, - entries: BTreeMap, Vec>, - current_key: Option>, +pub enum Compound<'a, W> { + Map { + ser: &'a mut Serializer, + entries: BTreeMap, Vec>, + current_key: Option>, + }, + // #[cfg(feature = "raw_value")] + RawValue { + ser: &'a mut Serializer, + }, } -impl<'a, W> SerializeMap<'a, W> +impl<'a, W> Compound<'a, W> where W: Write, { #[inline] - fn new(ser: &'a mut Serializer) -> Self { - SerializeMap { + fn new_map(ser: &'a mut Serializer) -> Self { + Compound::Map { ser, entries: BTreeMap::new(), current_key: None, } } - - #[inline] - fn end_map(&mut self) -> Result<()> { - if self.current_key.is_some() { - return Err(Error::with_kind(ErrorKind::KeyWithoutValue)); - } - - for (k, v) in &self.entries { - ser::Serializer::serialize_bytes(&mut *self.ser, k.as_ref())?; - self.ser.writer.write_all(v)?; - } - - Ok(()) - } } -impl ser::SerializeMap for SerializeMap<'_, W> +impl ser::SerializeMap for Compound<'_, W> where W: Write, { @@ -407,11 +766,16 @@ where where T: ?Sized + Serialize, { - if self.current_key.is_some() { - return Err(Error::with_kind(ErrorKind::KeyWithoutValue)); + match self { + Compound::Map { current_key, .. } => { + if current_key.is_some() { + return Err(Error::with_kind(ErrorKind::KeyWithoutValue)); + } + *current_key = Some(key.serialize(&mut MapKeySerializer {})?); + Ok(()) + } + Compound::RawValue { .. } => unreachable!(), } - self.current_key = Some(key.serialize(&mut MapKeySerializer {})?); - Ok(()) } #[inline] @@ -419,26 +783,50 @@ where where T: ?Sized + Serialize, { - let key = self - .current_key - .take() - .ok_or_else(|| Error::with_kind(ErrorKind::ValueWithoutKey))?; - let buf: Vec = Vec::new(); - let mut ser = Serializer::new(buf); - value.serialize(&mut ser)?; - self.entries.insert(key, ser.into_inner()); - Ok(()) + match self { + Compound::Map { + entries, + current_key, + .. + } => { + let key = current_key + .take() + .ok_or_else(|| Error::with_kind(ErrorKind::ValueWithoutKey))?; + let buf: Vec = Vec::new(); + let mut ser = Serializer::new(buf); // TODO: optimize? + value.serialize(&mut ser)?; + entries.insert(key, ser.into_inner()); + Ok(()) + } + Compound::RawValue { .. } => unreachable!(), + } } #[inline] fn end(mut self) -> Result<()> { - self.end_map()?; - self.ser.writer.write_all(b"e")?; - Ok(()) + match self { + Compound::Map { + ref mut ser, + entries, + current_key, + } => { + if current_key.is_some() { + return Err(Error::with_kind(ErrorKind::KeyWithoutValue)); + } + + for (k, v) in entries { + ser::Serializer::serialize_bytes(&mut **ser, k.as_ref())?; + ser.writer.write_all(&v)?; + } + ser.writer.write_all(b"e")?; + Ok(()) + } + Compound::RawValue { .. } => unreachable!(), + } } } -impl ser::SerializeStruct for SerializeMap<'_, W> +impl ser::SerializeStruct for Compound<'_, W> where W: Write, { @@ -450,20 +838,34 @@ where where T: ?Sized + Serialize, { - let key = key.serialize(&mut MapKeySerializer {})?; - - let buf: Vec = Vec::new(); - let mut ser = Serializer::new(buf); - value.serialize(&mut ser)?; - self.entries.insert(key, ser.into_inner()); - Ok(()) + match self { + Compound::Map { entries, .. } => { + let key = key.serialize(&mut MapKeySerializer {})?; + + let buf: Vec = Vec::new(); + let mut ser = Serializer::new(buf); + value.serialize(&mut ser)?; + entries.insert(key, ser.into_inner()); + Ok(()) + } + // #[cfg(feature = "raw_value")] + Compound::RawValue { ser } => { + if key == crate::raw::TOKEN { + value.serialize(&mut **ser) + } else { + Err(Error::with_kind(ErrorKind::UnsupportedType)) + } + } + } } #[inline] - fn end(mut self) -> Result<()> { - self.end_map()?; - self.ser.writer.write_all(b"e")?; - Ok(()) + fn end(self) -> Result<()> { + match self { + Compound::Map { .. } => ser::SerializeMap::end(self), + // #[cfg(feature = "raw_value")] + Compound::RawValue { .. } => Ok(()), + } } }