diff --git a/zebra-network/src/protocol/external/codec.rs b/zebra-network/src/protocol/external/codec.rs index d867539c..f854db40 100644 --- a/zebra-network/src/protocol/external/codec.rs +++ b/zebra-network/src/protocol/external/codec.rs @@ -4,7 +4,7 @@ use std::fmt; use std::io::{Cursor, Read, Write}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -use bytes::BytesMut; +use bytes::{BufMut, BytesMut}; use chrono::{TimeZone, Utc}; use tokio_util::codec::{Decoder, Encoder}; @@ -109,20 +109,15 @@ impl Encoder for Codec { fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { use Error::Parse; - // XXX(HACK): this is inefficient and does an extra allocation. - // instead, we should have a size estimator for the message, reserve - // that much space, write the header (with zeroed checksum), then the body, - // then write the computed checksum in-place. for now, just do an extra alloc. - let mut body = Vec::new(); - self.write_body(&item, &mut body)?; + let body_length = self.body_length(&item); - if body.len() > self.builder.max_len { + if body_length > self.builder.max_len { return Err(Parse("body length exceeded maximum size")); } if let Some(label) = self.builder.metrics_label.clone() { - metrics::counter!("bytes.written", (body.len() + HEADER_LEN) as u64, "addr" => label); + metrics::counter!("bytes.written", (body_length + HEADER_LEN) as u64, "addr" => label); } use Message::*; @@ -152,26 +147,58 @@ impl Encoder for Codec { FilterAdd { .. } => b"filteradd\0\0\0", FilterClear { .. } => b"filterclear\0", }; - trace!(?item, len = body.len()); + trace!(?item, len = body_length); - // XXX this should write directly into the buffer, - // but leave it for now until we fix the issue above. - let mut header = [0u8; HEADER_LEN]; - let mut header_writer = Cursor::new(&mut header[..]); - header_writer.write_all(&Magic::from(self.builder.network).0[..])?; - header_writer.write_all(command)?; - header_writer.write_u32::(body.len() as u32)?; - header_writer.write_all(&sha256d::Checksum::from(&body[..]).0)?; + dst.reserve(HEADER_LEN + body_length); + let start_len = dst.len(); + { + let dst = &mut dst.writer(); + dst.write_all(&Magic::from(self.builder.network).0[..])?; + dst.write_all(command)?; + dst.write_u32::(body_length as u32)?; - dst.reserve(HEADER_LEN + body.len()); - dst.extend_from_slice(&header); - dst.extend_from_slice(&body); + // We zero the checksum at first, and compute it later + // after the body has been written. + dst.write_u32::(0)?; + + self.write_body(&item, dst)?; + } + let checksum = sha256d::Checksum::from(&dst[start_len + HEADER_LEN..]); + dst[start_len + 20..][..4].copy_from_slice(&checksum.0); Ok(()) } } impl Codec { + /// Obtain the size of the body of a given message. This will match the + /// number of bytes written to the writer provided to `write_body` for the + /// same message. + /// + /// TODO: Replace with a size estimate, to avoid multiple serializations + /// for large data structures like lists, blocks, and transactions. + /// See #1774. + fn body_length(&self, msg: &Message) -> usize { + struct FakeWriter(usize); + + impl std::io::Write for FakeWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.0 += buf.len(); + + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } + } + + let mut writer = FakeWriter(0); + self.write_body(msg, &mut writer) + .expect("writer should never fail"); + writer.0 + } + /// Write the body of the message into the given writer. This allows writing /// the message body prior to writing the header, so that the header can /// contain a checksum of the message body.