diff --git a/zebra-network/src/peer/connection.rs b/zebra-network/src/peer/connection.rs index 38c78da8..893b6388 100644 --- a/zebra-network/src/peer/connection.rs +++ b/zebra-network/src/peer/connection.rs @@ -378,7 +378,10 @@ where } Either::Left((Some(Err(e)), _)) => self.fail_with(e), Either::Left((Some(Ok(msg)), _)) => { - self.handle_message_as_request(msg).await + match self.handle_message_as_request(msg).await { + Ok(()) => {} + Err(e) => self.fail_with(e), + } } Either::Right((None, _)) => { trace!("client_rx closed, ending connection"); @@ -434,7 +437,10 @@ where if let Some(msg) = request_msg { // do NOT instrument with the request span, this is // independent work - self.handle_message_as_request(msg).await; + match self.handle_message_as_request(msg).await { + Ok(()) => {} + Err(e) => self.fail_with(e), + } } else { // Otherwise, check whether the handler is finished // processing messages and update the state. @@ -565,18 +571,53 @@ where /// NOTE: the caller should use .instrument(msg.span) to instrument the function. async fn handle_client_request(&mut self, req: InProgressClientRequest) { trace!(?req.request); - use Request::*; use State::*; - let InProgressClientRequest { request, tx, span } = req; - if tx.is_canceled() { + if req.tx.is_canceled() { metrics::counter!("peer.canceled", 1); tracing::debug!("ignoring canceled request"); return; } - // These matches return a Result with (new_state, Option) or an (error, Sender) - let new_state_result = match (&self.state, request) { + let new_state_result = self._handle_client_request(req).await; + + // Updates state or fails. + match new_state_result { + Ok(AwaitingRequest) => { + self.state = AwaitingRequest; + self.request_timer = Some(sleep(constants::REQUEST_TIMEOUT)); + } + Ok(new_state @ AwaitingResponse { .. }) => { + self.state = new_state; + self.request_timer = Some(sleep(constants::REQUEST_TIMEOUT)); + } + Err((e, tx)) => { + let e = SharedPeerError::from(e); + let _ = tx.send(Err(e.clone())); + self.fail_with(e); + } + // unreachable states + Ok(Failed) => unreachable!( + "failed client requests must use fail_with(error) to reach a Failed state." + ), + }; + } + + async fn _handle_client_request( + &mut self, + req: InProgressClientRequest, + ) -> Result< + State, + ( + SerializationError, + MustUseOneshotSender>, + ), + > { + use Request::*; + use State::*; + let InProgressClientRequest { request, tx, span } = req; + + match (&self.state, request) { (Failed, request) => panic!( "failed connection cannot handle new request: {:?}, client_receiver: {:?}", request, @@ -589,25 +630,23 @@ where self.client_rx ), (AwaitingRequest, Peers) => match self.peer_tx.send(Message::GetAddr).await { - Ok(()) => Ok(( + Ok(()) => Ok( AwaitingResponse { handler: Handler::Peers, tx, span, }, - None, - )), + ), Err(e) => Err((e, tx)), }, (AwaitingRequest, Ping(nonce)) => match self.peer_tx.send(Message::Ping(nonce)).await { - Ok(()) => Ok(( + Ok(()) => Ok( AwaitingResponse { handler: Handler::Ping(nonce), tx, span, }, - None, - )), + ), Err(e) => Err((e, tx)), }, (AwaitingRequest, BlocksByHash(hashes)) => { @@ -618,7 +657,7 @@ where )) .await { - Ok(()) => Ok(( + Ok(()) => Ok( AwaitingResponse { handler: Handler::BlocksByHash { blocks: Vec::with_capacity(hashes.len()), @@ -627,8 +666,7 @@ where tx, span, }, - None, - )), + ), Err(e) => Err((e, tx)), } } @@ -640,7 +678,7 @@ where )) .await { - Ok(()) => Ok(( + Ok(()) => Ok( AwaitingResponse { handler: Handler::TransactionsByHash { transactions: Vec::with_capacity(hashes.len()), @@ -649,8 +687,7 @@ where tx, span, }, - None, - )), + ), Err(e) => Err((e, tx)), } } @@ -660,14 +697,13 @@ where .send(Message::GetBlocks { known_blocks, stop }) .await { - Ok(()) => Ok(( + Ok(()) => Ok( AwaitingResponse { handler: Handler::FindBlocks, tx, span, }, - None, - )), + ), Err(e) => Err((e, tx)), } } @@ -677,33 +713,36 @@ where .send(Message::GetHeaders { known_blocks, stop }) .await { - Ok(()) => Ok(( + Ok(()) => Ok( AwaitingResponse { handler: Handler::FindHeaders, tx, span, }, - None, - )), + ), Err(e) => Err((e, tx)), } } (AwaitingRequest, MempoolTransactions) => { match self.peer_tx.send(Message::Mempool).await { - Ok(()) => Ok(( + Ok(()) => Ok( AwaitingResponse { handler: Handler::MempoolTransactions, tx, span, }, - None, - )), + ), Err(e) => Err((e, tx)), } } (AwaitingRequest, PushTransaction(transaction)) => { match self.peer_tx.send(Message::Tx(transaction)).await { - Ok(()) => Ok((AwaitingRequest, Some(tx))), + Ok(()) => { + // Since we're not waiting for further messages, we need to + // send a response before dropping tx. + let _ = tx.send(Ok(Response::Nil)); + Ok(AwaitingRequest) + }, Err(e) => Err((e, tx)), } } @@ -713,110 +752,78 @@ where .send(Message::Inv(hashes.iter().map(|h| (*h).into()).collect())) .await { - Ok(()) => Ok((AwaitingRequest, Some(tx))), + Ok(()) => { + // Since we're not waiting for further messages, we need to + // send a response before dropping tx. + let _ = tx.send(Ok(Response::Nil)); + Ok(AwaitingRequest) + }, Err(e) => Err((e, tx)), } } (AwaitingRequest, AdvertiseBlock(hash)) => { match self.peer_tx.send(Message::Inv(vec![hash.into()])).await { - Ok(()) => Ok((AwaitingRequest, Some(tx))), + Ok(()) => { + // Since we're not waiting for further messages, we need to + // send a response before dropping tx. + let _ = tx.send(Ok(Response::Nil)); + Ok(AwaitingRequest) + }, Err(e) => Err((e, tx)), } } - }; - // Updates state or fails. Sends the error on the Sender if it is Some. - match new_state_result { - Ok((AwaitingRequest, Some(tx))) => { - // Since we're not waiting for further messages, we need to - // send a response before dropping tx. - let _ = tx.send(Ok(Response::Nil)); - self.state = AwaitingRequest; - self.request_timer = Some(sleep(constants::REQUEST_TIMEOUT)); - } - Ok((new_state @ AwaitingResponse { .. }, None)) => { - self.state = new_state; - self.request_timer = Some(sleep(constants::REQUEST_TIMEOUT)); - } - Err((e, tx)) => { - let e = SharedPeerError::from(e); - let _ = tx.send(Err(e.clone())); - self.fail_with(e); - } - // unreachable states - Ok((Failed, tx)) => unreachable!( - "failed client requests must use fail_with(error) to reach a Failed state. tx: {:?}", - tx - ), - Ok((AwaitingRequest, None)) => unreachable!( - "successful AwaitingRequest states must send a response on tx, but tx is None", - ), - Ok((new_state @ AwaitingResponse { .. }, Some(tx))) => unreachable!( - "successful AwaitingResponse states must keep tx, but tx is Some: {:?} for: {:?}", - tx, new_state, - ), - }; + } } // This function has its own span, because we're creating a new work // context (namely, the work of processing the inbound msg as a request) #[instrument(name = "msg_as_req", skip(self, msg), fields(%msg))] - async fn handle_message_as_request(&mut self, msg: Message) { + async fn handle_message_as_request(&mut self, msg: Message) -> Result<(), PeerError> { trace!(?msg); let req = match msg { Message::Ping(nonce) => { trace!(?nonce, "responding to heartbeat"); - if let Err(e) = self.peer_tx.send(Message::Pong(nonce)).await { - self.fail_with(e); - } - return; + self.peer_tx.send(Message::Pong(nonce)).await?; + return Ok(()); } // These messages shouldn't be sent outside of a handshake. - Message::Version { .. } => { - self.fail_with(PeerError::DuplicateHandshake); - return; - } - Message::Verack { .. } => { - self.fail_with(PeerError::DuplicateHandshake); - return; - } + Message::Version { .. } => Err(PeerError::DuplicateHandshake)?, + Message::Verack { .. } => Err(PeerError::DuplicateHandshake)?, // These messages should already be handled as a response if they // could be a response, so if we see them here, they were either // sent unsolicited, or they were sent in response to a canceled request // that we've already forgotten about. Message::Reject { .. } => { tracing::debug!("got reject message unsolicited or from canceled request"); - return; + return Ok(()); } Message::NotFound { .. } => { tracing::debug!("got notfound message unsolicited or from canceled request"); - return; + return Ok(()); } Message::Pong(_) => { tracing::debug!("got pong message unsolicited or from canceled request"); - return; + return Ok(()); } Message::Block(_) => { tracing::debug!("got block message unsolicited or from canceled request"); - return; + return Ok(()); } Message::Headers(_) => { tracing::debug!("got headers message unsolicited or from canceled request"); - return; + return Ok(()); } // These messages should never be sent by peers. Message::FilterLoad { .. } | Message::FilterAdd { .. } - | Message::FilterClear { .. } => { - self.fail_with(PeerError::UnsupportedMessage( - "got BIP11 message without advertising NODE_BLOOM", - )); - return; - } + | Message::FilterClear { .. } => Err(PeerError::UnsupportedMessage( + "got BIP11 message without advertising NODE_BLOOM", + ))?, // Zebra crawls the network proactively, to prevent // peers from inserting data into our address book. Message::Addr(_) => { trace!("ignoring unsolicited addr message"); - return; + return Ok(()); } Message::Tx(transaction) => Request::PushTransaction(transaction), Message::Inv(items) => match &items[..] { @@ -828,10 +835,7 @@ where { Request::TransactionsByHash(transaction_hashes(&items).collect()) } - _ => { - self.fail_with(PeerError::WrongMessage("inv with mixed item types")); - return; - } + _ => Err(PeerError::WrongMessage("inv with mixed item types"))?, }, Message::GetData(items) => match &items[..] { [InventoryHash::Block(_), rest @ ..] @@ -846,10 +850,7 @@ where { Request::TransactionsByHash(transaction_hashes(&items).collect()) } - _ => { - self.fail_with(PeerError::WrongMessage("getdata with mixed item types")); - return; - } + _ => Err(PeerError::WrongMessage("getdata with mixed item types"))?, }, Message::GetAddr => Request::Peers, Message::GetBlocks { known_blocks, stop } => Request::FindBlocks { known_blocks, stop }, @@ -859,7 +860,9 @@ where Message::Mempool => Request::MempoolTransactions, }; - self.drive_peer_request(req).await + self.drive_peer_request(req).await?; + + Ok(()) } /// Given a `req` originating from the peer, drive it to completion and send @@ -867,15 +870,14 @@ where /// processing the request (e.g., the service is shedding load), then we call /// fail_with to terminate the entire peer connection, shrinking the number /// of connected peers. - async fn drive_peer_request(&mut self, req: Request) { + async fn drive_peer_request(&mut self, req: Request) -> Result<(), PeerError> { trace!(?req); use tower::{load_shed::error::Overloaded, ServiceExt}; if self.svc.ready_and().await.is_err() { // Treat all service readiness errors as Overloaded // TODO: treat `TryRecvError::Closed` in `Inbound::poll_ready` as a fatal error (#1655) - self.fail_with(PeerError::Overloaded); - return; + Err(PeerError::Overloaded)? } let rsp = match self.svc.call(req).await { @@ -883,7 +885,7 @@ where if e.is::() { tracing::warn!("inbound service is overloaded, closing connection"); metrics::counter!("pool.closed.loadshed", 1); - self.fail_with(PeerError::Overloaded); + Err(PeerError::Overloaded)? } else { // We could send a reject to the remote peer, but that might cause // them to disconnect, and we might be using them to sync blocks. @@ -894,58 +896,40 @@ where client_receiver = ?self.client_rx, "error processing peer request"); } - return; + return Ok(()); } Ok(rsp) => rsp, }; match rsp { - Response::Nil => { /* generic success, do nothing */ } - Response::Peers(addrs) => { - if let Err(e) = self.peer_tx.send(Message::Addr(addrs)).await { - self.fail_with(e); - } - } + Response::Nil => { /* generic success, do nothing */ }, + Response::Peers(addrs) => self.peer_tx.send(Message::Addr(addrs)).await?, Response::Transactions(transactions) => { // Generate one tx message per transaction. for transaction in transactions.into_iter() { - if let Err(e) = self.peer_tx.send(Message::Tx(transaction)).await { - self.fail_with(e); - } + self.peer_tx.send(Message::Tx(transaction)).await?; } } Response::Blocks(blocks) => { // Generate one block message per block. for block in blocks.into_iter() { - if let Err(e) = self.peer_tx.send(Message::Block(block)).await { - self.fail_with(e); - } + self.peer_tx.send(Message::Block(block)).await?; } } Response::BlockHashes(hashes) => { - if let Err(e) = self - .peer_tx + self.peer_tx .send(Message::Inv(hashes.into_iter().map(Into::into).collect())) - .await - { - self.fail_with(e) - } - } - Response::BlockHeaders(headers) => { - if let Err(e) = self.peer_tx.send(Message::Headers(headers)).await { - self.fail_with(e) - } + .await? } + Response::BlockHeaders(headers) => self.peer_tx.send(Message::Headers(headers)).await?, Response::TransactionHashes(hashes) => { - if let Err(e) = self - .peer_tx + self.peer_tx .send(Message::Inv(hashes.into_iter().map(Into::into).collect())) - .await - { - self.fail_with(e) - } + .await? } } + + Ok(()) } }