diff --git a/zebra-network/src/protocol/external/codec.rs b/zebra-network/src/protocol/external/codec.rs index 0b665235..10b79dad 100644 --- a/zebra-network/src/protocol/external/codec.rs +++ b/zebra-network/src/protocol/external/codec.rs @@ -1,5 +1,6 @@ //! A Tokio codec mapping byte streams to Bitcoin message streams. +use std::convert::TryInto; use std::fmt; use std::io::{Cursor, Read, Write}; @@ -278,7 +279,7 @@ impl Codec { ref tweak, ref flags, } => { - filter.zcash_serialize(&mut writer)?; + writer.write_all(&filter.0)?; writer.write_u32::(*hash_functions_count)?; writer.write_u32::(tweak.0)?; writer.write_u8(*flags)?; @@ -598,9 +599,19 @@ impl Codec { Ok(Message::Mempool) } - fn read_filterload(&self, mut reader: R) -> Result { + fn read_filterload(&self, mut reader: Cursor<&BytesMut>) -> Result { + const MAX_FILTER_LENGTH: usize = 36000; + const FILTERLOAD_REMAINDER_LENGTH: usize = 4 + 4 + 1; + + let filter_length: usize = (reader.get_ref().len() - FILTERLOAD_REMAINDER_LENGTH) + .try_into() + .unwrap(); + + let mut filter_bytes = vec![0; std::cmp::min(filter_length, MAX_FILTER_LENGTH)]; + reader.read_exact(&mut filter_bytes)?; + Ok(Message::FilterLoad { - filter: Filter::zcash_deserialize(&mut reader)?, + filter: Filter(filter_bytes), hash_functions_count: reader.read_u32::()?, tweak: Tweak(reader.read_u32::()?), flags: reader.read_u8()?, @@ -679,6 +690,76 @@ mod tests { assert_eq!(v, v_parsed); } + #[test] + fn filterload_message_round_trip() { + let rt = Runtime::new().unwrap(); + + let v = Message::FilterLoad { + filter: Filter(vec![0; 35999]), + hash_functions_count: 0, + tweak: Tweak(0), + flags: 0, + }; + + use tokio::codec::{FramedRead, FramedWrite}; + use tokio::prelude::*; + let v_bytes = rt.block_on(async { + let mut bytes = Vec::new(); + { + let mut fw = FramedWrite::new(&mut bytes, Codec::builder().finish()); + fw.send(v.clone()) + .await + .expect("message should be serialized"); + } + bytes + }); + + let v_parsed = rt.block_on(async { + let mut fr = FramedRead::new(Cursor::new(&v_bytes), Codec::builder().finish()); + fr.next() + .await + .expect("a next message should be available") + .expect("that message should deserialize") + }); + + assert_eq!(v, v_parsed); + } + + #[test] + fn filterload_message_too_large_round_trip() { + let rt = Runtime::new().unwrap(); + + let v = Message::FilterLoad { + filter: Filter(vec![0; 40000]), + hash_functions_count: 0, + tweak: Tweak(0), + flags: 0, + }; + + use tokio::codec::{FramedRead, FramedWrite}; + use tokio::prelude::*; + let v_bytes = rt.block_on(async { + let mut bytes = Vec::new(); + { + let mut fw = FramedWrite::new(&mut bytes, Codec::builder().finish()); + fw.send(v.clone()) + .await + .expect("message should be serialized"); + } + bytes + }); + + let v_parsed = rt.block_on(async { + let mut fr = FramedRead::new(Cursor::new(&v_bytes), Codec::builder().finish()); + fr.next() + .await + .expect("a next message should be available") + .expect("that message should deserialize") + }); + + assert_ne!(v, v_parsed); + } + #[test] fn decode_state_debug() { assert_eq!(format!("{:?}", DecodeState::Head), "DecodeState::Head");