diff --git a/zebra-network/src/peer/connection.rs b/zebra-network/src/peer/connection.rs index 893b6388..38c78da8 100644 --- a/zebra-network/src/peer/connection.rs +++ b/zebra-network/src/peer/connection.rs @@ -378,10 +378,7 @@ where } Either::Left((Some(Err(e)), _)) => self.fail_with(e), Either::Left((Some(Ok(msg)), _)) => { - match self.handle_message_as_request(msg).await { - Ok(()) => {} - Err(e) => self.fail_with(e), - } + self.handle_message_as_request(msg).await } Either::Right((None, _)) => { trace!("client_rx closed, ending connection"); @@ -437,10 +434,7 @@ where if let Some(msg) = request_msg { // do NOT instrument with the request span, this is // independent work - match self.handle_message_as_request(msg).await { - Ok(()) => {} - Err(e) => self.fail_with(e), - } + self.handle_message_as_request(msg).await; } else { // Otherwise, check whether the handler is finished // processing messages and update the state. @@ -571,53 +565,18 @@ where /// NOTE: the caller should use .instrument(msg.span) to instrument the function. async fn handle_client_request(&mut self, req: InProgressClientRequest) { trace!(?req.request); + use Request::*; use State::*; + let InProgressClientRequest { request, tx, span } = req; - if req.tx.is_canceled() { + if tx.is_canceled() { metrics::counter!("peer.canceled", 1); tracing::debug!("ignoring canceled request"); return; } - let new_state_result = self._handle_client_request(req).await; - - // Updates state or fails. - match new_state_result { - Ok(AwaitingRequest) => { - self.state = AwaitingRequest; - self.request_timer = Some(sleep(constants::REQUEST_TIMEOUT)); - } - Ok(new_state @ AwaitingResponse { .. }) => { - self.state = new_state; - self.request_timer = Some(sleep(constants::REQUEST_TIMEOUT)); - } - Err((e, tx)) => { - let e = SharedPeerError::from(e); - let _ = tx.send(Err(e.clone())); - self.fail_with(e); - } - // unreachable states - Ok(Failed) => unreachable!( - "failed client requests must use fail_with(error) to reach a Failed state." - ), - }; - } - - async fn _handle_client_request( - &mut self, - req: InProgressClientRequest, - ) -> Result< - State, - ( - SerializationError, - MustUseOneshotSender>, - ), - > { - use Request::*; - use State::*; - let InProgressClientRequest { request, tx, span } = req; - - match (&self.state, request) { + // These matches return a Result with (new_state, Option) or an (error, Sender) + let new_state_result = match (&self.state, request) { (Failed, request) => panic!( "failed connection cannot handle new request: {:?}, client_receiver: {:?}", request, @@ -630,23 +589,25 @@ where self.client_rx ), (AwaitingRequest, Peers) => match self.peer_tx.send(Message::GetAddr).await { - Ok(()) => Ok( + Ok(()) => Ok(( AwaitingResponse { handler: Handler::Peers, tx, span, }, - ), + None, + )), Err(e) => Err((e, tx)), }, (AwaitingRequest, Ping(nonce)) => match self.peer_tx.send(Message::Ping(nonce)).await { - Ok(()) => Ok( + Ok(()) => Ok(( AwaitingResponse { handler: Handler::Ping(nonce), tx, span, }, - ), + None, + )), Err(e) => Err((e, tx)), }, (AwaitingRequest, BlocksByHash(hashes)) => { @@ -657,7 +618,7 @@ where )) .await { - Ok(()) => Ok( + Ok(()) => Ok(( AwaitingResponse { handler: Handler::BlocksByHash { blocks: Vec::with_capacity(hashes.len()), @@ -666,7 +627,8 @@ where tx, span, }, - ), + None, + )), Err(e) => Err((e, tx)), } } @@ -678,7 +640,7 @@ where )) .await { - Ok(()) => Ok( + Ok(()) => Ok(( AwaitingResponse { handler: Handler::TransactionsByHash { transactions: Vec::with_capacity(hashes.len()), @@ -687,7 +649,8 @@ where tx, span, }, - ), + None, + )), Err(e) => Err((e, tx)), } } @@ -697,13 +660,14 @@ where .send(Message::GetBlocks { known_blocks, stop }) .await { - Ok(()) => Ok( + Ok(()) => Ok(( AwaitingResponse { handler: Handler::FindBlocks, tx, span, }, - ), + None, + )), Err(e) => Err((e, tx)), } } @@ -713,36 +677,33 @@ where .send(Message::GetHeaders { known_blocks, stop }) .await { - Ok(()) => Ok( + Ok(()) => Ok(( AwaitingResponse { handler: Handler::FindHeaders, tx, span, }, - ), + None, + )), Err(e) => Err((e, tx)), } } (AwaitingRequest, MempoolTransactions) => { match self.peer_tx.send(Message::Mempool).await { - Ok(()) => Ok( + Ok(()) => Ok(( AwaitingResponse { handler: Handler::MempoolTransactions, tx, span, }, - ), + None, + )), Err(e) => Err((e, tx)), } } (AwaitingRequest, PushTransaction(transaction)) => { match self.peer_tx.send(Message::Tx(transaction)).await { - Ok(()) => { - // Since we're not waiting for further messages, we need to - // send a response before dropping tx. - let _ = tx.send(Ok(Response::Nil)); - Ok(AwaitingRequest) - }, + Ok(()) => Ok((AwaitingRequest, Some(tx))), Err(e) => Err((e, tx)), } } @@ -752,78 +713,110 @@ where .send(Message::Inv(hashes.iter().map(|h| (*h).into()).collect())) .await { - Ok(()) => { - // Since we're not waiting for further messages, we need to - // send a response before dropping tx. - let _ = tx.send(Ok(Response::Nil)); - Ok(AwaitingRequest) - }, + Ok(()) => Ok((AwaitingRequest, Some(tx))), Err(e) => Err((e, tx)), } } (AwaitingRequest, AdvertiseBlock(hash)) => { match self.peer_tx.send(Message::Inv(vec![hash.into()])).await { - Ok(()) => { - // Since we're not waiting for further messages, we need to - // send a response before dropping tx. - let _ = tx.send(Ok(Response::Nil)); - Ok(AwaitingRequest) - }, + Ok(()) => Ok((AwaitingRequest, Some(tx))), Err(e) => Err((e, tx)), } } - } + }; + // Updates state or fails. Sends the error on the Sender if it is Some. + match new_state_result { + Ok((AwaitingRequest, Some(tx))) => { + // Since we're not waiting for further messages, we need to + // send a response before dropping tx. + let _ = tx.send(Ok(Response::Nil)); + self.state = AwaitingRequest; + self.request_timer = Some(sleep(constants::REQUEST_TIMEOUT)); + } + Ok((new_state @ AwaitingResponse { .. }, None)) => { + self.state = new_state; + self.request_timer = Some(sleep(constants::REQUEST_TIMEOUT)); + } + Err((e, tx)) => { + let e = SharedPeerError::from(e); + let _ = tx.send(Err(e.clone())); + self.fail_with(e); + } + // unreachable states + Ok((Failed, tx)) => unreachable!( + "failed client requests must use fail_with(error) to reach a Failed state. tx: {:?}", + tx + ), + Ok((AwaitingRequest, None)) => unreachable!( + "successful AwaitingRequest states must send a response on tx, but tx is None", + ), + Ok((new_state @ AwaitingResponse { .. }, Some(tx))) => unreachable!( + "successful AwaitingResponse states must keep tx, but tx is Some: {:?} for: {:?}", + tx, new_state, + ), + }; } // This function has its own span, because we're creating a new work // context (namely, the work of processing the inbound msg as a request) #[instrument(name = "msg_as_req", skip(self, msg), fields(%msg))] - async fn handle_message_as_request(&mut self, msg: Message) -> Result<(), PeerError> { + async fn handle_message_as_request(&mut self, msg: Message) { trace!(?msg); let req = match msg { Message::Ping(nonce) => { trace!(?nonce, "responding to heartbeat"); - self.peer_tx.send(Message::Pong(nonce)).await?; - return Ok(()); + if let Err(e) = self.peer_tx.send(Message::Pong(nonce)).await { + self.fail_with(e); + } + return; } // These messages shouldn't be sent outside of a handshake. - Message::Version { .. } => Err(PeerError::DuplicateHandshake)?, - Message::Verack { .. } => Err(PeerError::DuplicateHandshake)?, + Message::Version { .. } => { + self.fail_with(PeerError::DuplicateHandshake); + return; + } + Message::Verack { .. } => { + self.fail_with(PeerError::DuplicateHandshake); + return; + } // These messages should already be handled as a response if they // could be a response, so if we see them here, they were either // sent unsolicited, or they were sent in response to a canceled request // that we've already forgotten about. Message::Reject { .. } => { tracing::debug!("got reject message unsolicited or from canceled request"); - return Ok(()); + return; } Message::NotFound { .. } => { tracing::debug!("got notfound message unsolicited or from canceled request"); - return Ok(()); + return; } Message::Pong(_) => { tracing::debug!("got pong message unsolicited or from canceled request"); - return Ok(()); + return; } Message::Block(_) => { tracing::debug!("got block message unsolicited or from canceled request"); - return Ok(()); + return; } Message::Headers(_) => { tracing::debug!("got headers message unsolicited or from canceled request"); - return Ok(()); + return; } // These messages should never be sent by peers. Message::FilterLoad { .. } | Message::FilterAdd { .. } - | Message::FilterClear { .. } => Err(PeerError::UnsupportedMessage( - "got BIP11 message without advertising NODE_BLOOM", - ))?, + | Message::FilterClear { .. } => { + self.fail_with(PeerError::UnsupportedMessage( + "got BIP11 message without advertising NODE_BLOOM", + )); + return; + } // Zebra crawls the network proactively, to prevent // peers from inserting data into our address book. Message::Addr(_) => { trace!("ignoring unsolicited addr message"); - return Ok(()); + return; } Message::Tx(transaction) => Request::PushTransaction(transaction), Message::Inv(items) => match &items[..] { @@ -835,7 +828,10 @@ where { Request::TransactionsByHash(transaction_hashes(&items).collect()) } - _ => Err(PeerError::WrongMessage("inv with mixed item types"))?, + _ => { + self.fail_with(PeerError::WrongMessage("inv with mixed item types")); + return; + } }, Message::GetData(items) => match &items[..] { [InventoryHash::Block(_), rest @ ..] @@ -850,7 +846,10 @@ where { Request::TransactionsByHash(transaction_hashes(&items).collect()) } - _ => Err(PeerError::WrongMessage("getdata with mixed item types"))?, + _ => { + self.fail_with(PeerError::WrongMessage("getdata with mixed item types")); + return; + } }, Message::GetAddr => Request::Peers, Message::GetBlocks { known_blocks, stop } => Request::FindBlocks { known_blocks, stop }, @@ -860,9 +859,7 @@ where Message::Mempool => Request::MempoolTransactions, }; - self.drive_peer_request(req).await?; - - Ok(()) + self.drive_peer_request(req).await } /// Given a `req` originating from the peer, drive it to completion and send @@ -870,14 +867,15 @@ where /// processing the request (e.g., the service is shedding load), then we call /// fail_with to terminate the entire peer connection, shrinking the number /// of connected peers. - async fn drive_peer_request(&mut self, req: Request) -> Result<(), PeerError> { + async fn drive_peer_request(&mut self, req: Request) { trace!(?req); use tower::{load_shed::error::Overloaded, ServiceExt}; if self.svc.ready_and().await.is_err() { // Treat all service readiness errors as Overloaded // TODO: treat `TryRecvError::Closed` in `Inbound::poll_ready` as a fatal error (#1655) - Err(PeerError::Overloaded)? + self.fail_with(PeerError::Overloaded); + return; } let rsp = match self.svc.call(req).await { @@ -885,7 +883,7 @@ where if e.is::() { tracing::warn!("inbound service is overloaded, closing connection"); metrics::counter!("pool.closed.loadshed", 1); - Err(PeerError::Overloaded)? + self.fail_with(PeerError::Overloaded); } else { // We could send a reject to the remote peer, but that might cause // them to disconnect, and we might be using them to sync blocks. @@ -896,40 +894,58 @@ where client_receiver = ?self.client_rx, "error processing peer request"); } - return Ok(()); + return; } Ok(rsp) => rsp, }; match rsp { - Response::Nil => { /* generic success, do nothing */ }, - Response::Peers(addrs) => self.peer_tx.send(Message::Addr(addrs)).await?, + Response::Nil => { /* generic success, do nothing */ } + Response::Peers(addrs) => { + if let Err(e) = self.peer_tx.send(Message::Addr(addrs)).await { + self.fail_with(e); + } + } Response::Transactions(transactions) => { // Generate one tx message per transaction. for transaction in transactions.into_iter() { - self.peer_tx.send(Message::Tx(transaction)).await?; + if let Err(e) = self.peer_tx.send(Message::Tx(transaction)).await { + self.fail_with(e); + } } } Response::Blocks(blocks) => { // Generate one block message per block. for block in blocks.into_iter() { - self.peer_tx.send(Message::Block(block)).await?; + if let Err(e) = self.peer_tx.send(Message::Block(block)).await { + self.fail_with(e); + } } } Response::BlockHashes(hashes) => { - self.peer_tx + if let Err(e) = self + .peer_tx .send(Message::Inv(hashes.into_iter().map(Into::into).collect())) - .await? + .await + { + self.fail_with(e) + } + } + Response::BlockHeaders(headers) => { + if let Err(e) = self.peer_tx.send(Message::Headers(headers)).await { + self.fail_with(e) + } } - Response::BlockHeaders(headers) => self.peer_tx.send(Message::Headers(headers)).await?, Response::TransactionHashes(hashes) => { - self.peer_tx + if let Err(e) = self + .peer_tx .send(Message::Inv(hashes.into_iter().map(Into::into).collect())) - .await? + .await + { + self.fail_with(e) + } } } - - Ok(()) } }