diff --git a/zebra-rpc/src/methods.rs b/zebra-rpc/src/methods.rs index 59c80865..09cb120c 100644 --- a/zebra-rpc/src/methods.rs +++ b/zebra-rpc/src/methods.rs @@ -19,6 +19,7 @@ use tower::{buffer::Buffer, Service, ServiceExt}; use tracing::Instrument; use zebra_chain::{ + amount::{Amount, NonNegative}, block::{self, Height, SerializedBlock}, chain_tip::ChainTip, parameters::{ConsensusBranchId, Network, NetworkUpgrade}, @@ -69,6 +70,32 @@ pub trait Rpc { #[rpc(name = "getblockchaininfo")] fn get_blockchain_info(&self) -> Result; + /// Returns the total balance of a provided `addresses` in an [`AddressBalance`] instance. + /// + /// zcashd reference: [`getaddressbalance`](https://zcash.github.io/rpc/getaddressbalance.html) + /// + /// # Parameters + /// + /// - `address_strings`: (map) A JSON map with a single entry + /// - `addresses`: (array of strings) A list of base-58 encoded addresses. + /// + /// # Notes + /// + /// zcashd also accepts a single string parameter instead of an array of strings, but Zebra + /// doesn't because lightwalletd always calls this RPC with an array of addresses. + /// + /// zcashd also returns the total amount of Zatoshis received by the addresses, but Zebra + /// doesn't because lightwalletd doesn't use that information. + /// + /// The RPC documentation says that the returned object has a string `balance` field, but + /// zcashd actually [returns an + /// integer](https://github.com/zcash/lightwalletd/blob/bdaac63f3ee0dbef62bde04f6817a9f90d483b00/common/common.go#L128-L130). + #[rpc(name = "getaddressbalance")] + fn get_address_balance( + &self, + address_strings: AddressStrings, + ) -> BoxFuture>; + /// Sends the raw bytes of a signed transaction to the local node's mempool, if the transaction is valid. /// Returns the [`SentTransactionHash`] for the transaction, as a JSON string. /// @@ -369,6 +396,40 @@ where Ok(response) } + fn get_address_balance( + &self, + address_strings: AddressStrings, + ) -> BoxFuture> { + let state = self.state.clone(); + + async move { + let addresses: HashSet
= address_strings + .addresses + .into_iter() + .map(|address| { + address.parse().map_err(|error| { + Error::invalid_params(&format!("invalid address {address:?}: {error}")) + }) + }) + .collect::>()?; + + let request = zebra_state::ReadRequest::AddressBalance(addresses); + let response = state.oneshot(request).await.map_err(|error| Error { + code: ErrorCode::ServerError(0), + message: error.to_string(), + data: None, + })?; + + match response { + zebra_state::ReadResponse::AddressBalance(balance) => { + Ok(AddressBalance { balance }) + } + _ => unreachable!("Unexpected response from state service: {response:?}"), + } + } + .boxed() + } + fn send_raw_transaction( &self, raw_transaction_hex: String, @@ -657,6 +718,20 @@ pub struct GetBlockChainInfo { consensus: TipConsensusBranch, } +/// A wrapper type with a list of strings of addresses. +/// +/// This is used for the input parameter of [`Rpc::get_account_balance`]. +#[derive(Clone, Debug, Eq, PartialEq, Hash, serde::Deserialize)] +pub struct AddressStrings { + addresses: Vec, +} + +/// The transparent balance of a set of addresses. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, serde::Serialize)] +pub struct AddressBalance { + balance: Amount, +} + /// A hex-encoded [`ConsensusBranchId`] string. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)] struct ConsensusBranchIdHex(#[serde(with = "hex")] ConsensusBranchId); diff --git a/zebra-rpc/src/methods/tests/prop.rs b/zebra-rpc/src/methods/tests/prop.rs index 55c57a23..0146d8b4 100644 --- a/zebra-rpc/src/methods/tests/prop.rs +++ b/zebra-rpc/src/methods/tests/prop.rs @@ -2,14 +2,15 @@ use std::collections::HashSet; -use futures::FutureExt; +use futures::{join, FutureExt, TryFutureExt}; use hex::ToHex; use jsonrpc_core::{Error, ErrorCode}; -use proptest::prelude::*; +use proptest::{collection::vec, prelude::*}; use thiserror::Error; use tower::buffer::Buffer; use zebra_chain::{ + amount::{Amount, NonNegative}, block::{Block, Height}, chain_tip::{mock::MockChainTip, NoChainTip}, parameters::{ @@ -18,13 +19,16 @@ use zebra_chain::{ }, serialization::{ZcashDeserialize, ZcashSerialize}, transaction::{self, Transaction, UnminedTx, UnminedTxId}, + transparent, }; use zebra_node_services::mempool; use zebra_state::BoxError; use zebra_test::mock_service::MockService; -use super::super::{NetworkUpgradeStatus, Rpc, RpcImpl, SentTransactionHash}; +use super::super::{ + AddressBalance, AddressStrings, NetworkUpgradeStatus, Rpc, RpcImpl, SentTransactionHash, +}; proptest! { /// Test that when sending a raw transaction, it is received by the mempool service. @@ -304,8 +308,7 @@ proptest! { /// Make the mock mempool service return a list of transaction IDs, and check that the RPC call /// returns those IDs as hexadecimal strings. #[test] - fn mempool_transactions_are_sent_to_caller(transaction_ids in any::>()) - { + fn mempool_transactions_are_sent_to_caller(transaction_ids in any::>()) { let runtime = zebra_test::init_async(); let _guard = runtime.enter(); @@ -357,7 +360,9 @@ proptest! { /// Try to call `get_raw_transaction` using a string parameter that has at least one /// non-hexadecimal character, and check that it fails with an expected error. #[test] - fn get_raw_transaction_non_hexadecimal_string_results_in_an_error(non_hex_string in ".*[^0-9A-Fa-f].*") { + fn get_raw_transaction_non_hexadecimal_string_results_in_an_error( + non_hex_string in ".*[^0-9A-Fa-f].*", + ) { let runtime = zebra_test::init_async(); let _guard = runtime.enter(); @@ -409,7 +414,9 @@ proptest! { /// Try to call `get_raw_transaction` using random bytes that fail to deserialize as a /// transaction, and check that it fails with an expected error. #[test] - fn get_raw_transaction_invalid_transaction_results_in_an_error(random_bytes in any::>()) { + fn get_raw_transaction_invalid_transaction_results_in_an_error( + random_bytes in any::>(), + ) { let runtime = zebra_test::init_async(); let _guard = runtime.enter(); @@ -476,7 +483,10 @@ proptest! { ); let response = rpc.get_blockchain_info(); - prop_assert_eq!(&response.err().unwrap().message, "No Chain tip available yet"); + prop_assert_eq!( + &response.err().unwrap().message, + "No Chain tip available yet" + ); // The queue task should continue without errors or panics let rpc_tx_queue_task_result = rpc_tx_queue_task_handle.now_or_never(); @@ -529,8 +539,18 @@ proptest! { prop_assert_eq!(info.best_block_hash.0, block_hash); prop_assert!(info.estimated_height < Height::MAX.0); - prop_assert_eq!(info.consensus.chain_tip.0, NetworkUpgrade::current(network, block_height).branch_id().unwrap()); - prop_assert_eq!(info.consensus.next_block.0, NetworkUpgrade::current(network, (block_height + 1).unwrap()).branch_id().unwrap()); + prop_assert_eq!( + info.consensus.chain_tip.0, + NetworkUpgrade::current(network, block_height) + .branch_id() + .unwrap() + ); + prop_assert_eq!( + info.consensus.next_block.0, + NetworkUpgrade::current(network, (block_height + 1).unwrap()) + .branch_id() + .unwrap() + ); for u in info.upgrades { let mut status = NetworkUpgradeStatus::Active; @@ -539,10 +559,10 @@ proptest! { } prop_assert_eq!(u.1.status, status); } - }, + } Err(_) => { unreachable!("Test should never error with the data we are feeding it") - }, + } }; // The queue task should continue without errors or panics @@ -558,10 +578,133 @@ proptest! { })?; } + /// Test the `get_address_balance` RPC using an arbitrary set of addresses. + #[test] + fn queries_balance_for_valid_addresses( + network in any::(), + addresses in any::>(), + balance in any::>(), + ) { + let runtime = zebra_test::init_async(); + let _guard = runtime.enter(); + + let mut mempool = MockService::build().for_prop_tests(); + let mut state: MockService<_, _, _, BoxError> = MockService::build().for_prop_tests(); + + // Create a mocked `ChainTip` + let (chain_tip, _mock_chain_tip_sender) = MockChainTip::new(); + + // Prepare the list of addresses. + let address_strings = AddressStrings { + addresses: addresses + .iter() + .map(|address| address.to_string()) + .collect(), + }; + + tokio::time::pause(); + + // Start RPC with the mocked `ChainTip` + runtime.block_on(async move { + let (rpc, _rpc_tx_queue_task_handle) = RpcImpl::new( + "RPC test", + Buffer::new(mempool.clone(), 1), + Buffer::new(state.clone(), 1), + chain_tip, + network, + ); + + // Build the future to call the RPC + let call = rpc.get_address_balance(address_strings); + + // The RPC should perform a state query + let state_query = state + .expect_request(zebra_state::ReadRequest::AddressBalance(addresses)) + .map_ok(|responder| { + responder.respond(zebra_state::ReadResponse::AddressBalance(balance)) + }); + + // Await the RPC call and the state query + let (response, state_query_result) = join!(call, state_query); + + state_query_result?; + + // Check that response contains the expected balance + let received_balance = response?; + + prop_assert_eq!(received_balance, AddressBalance { balance }); + + // Check no further requests were made during this test + mempool.expect_no_requests().await?; + state.expect_no_requests().await?; + + Ok::<_, TestCaseError>(()) + })?; + } + + /// Test the `get_address_balance` RPC using an invalid list of addresses. + /// + /// An error should be returned. + #[test] + fn does_not_query_balance_for_invalid_addresses( + network in any::(), + at_least_one_invalid_address in vec(".*", 1..10), + ) { + let runtime = zebra_test::init_async(); + let _guard = runtime.enter(); + + prop_assume!(at_least_one_invalid_address + .iter() + .any(|string| string.parse::().is_err())); + + let mut mempool = MockService::build().for_prop_tests(); + let mut state: MockService<_, _, _, BoxError> = MockService::build().for_prop_tests(); + + // Create a mocked `ChainTip` + let (chain_tip, _mock_chain_tip_sender) = MockChainTip::new(); + + tokio::time::pause(); + + // Start RPC with the mocked `ChainTip` + runtime.block_on(async move { + let (rpc, _rpc_tx_queue_task_handle) = RpcImpl::new( + "RPC test", + Buffer::new(mempool.clone(), 1), + Buffer::new(state.clone(), 1), + chain_tip, + network, + ); + + let address_strings = AddressStrings { + addresses: at_least_one_invalid_address, + }; + + // Build the future to call the RPC + let result = rpc.get_address_balance(address_strings).await; + + // Check that the invalid addresses lead to an error + prop_assert!( + matches!( + result, + Err(Error { + code: ErrorCode::InvalidParams, + .. + }) + ), + "Result is not a server error: {result:?}" + ); + + // Check no requests were made during this test + mempool.expect_no_requests().await?; + state.expect_no_requests().await?; + + Ok::<_, TestCaseError>(()) + })?; + } + /// Test the queue functionality using `send_raw_transaction` #[test] - fn rpc_queue_main_loop(tx in any::()) - { + fn rpc_queue_main_loop(tx in any::()) { let runtime = zebra_test::init_async(); let _guard = runtime.enter(); @@ -627,7 +770,8 @@ proptest! { .respond(response); // now a retry will be sent to the mempool - let expected_request = mempool::Request::Queue(vec![mempool::Gossip::Tx(tx_unmined.clone())]); + let expected_request = + mempool::Request::Queue(vec![mempool::Gossip::Tx(tx_unmined.clone())]); let response = mempool::Response::Queued(vec![Ok(())]); mempool @@ -649,8 +793,7 @@ proptest! { /// Test we receive all transactions that are sent in a channel #[test] - fn rpc_queue_receives_all_transactions_from_channel(txs in any::<[Transaction; 2]>()) - { + fn rpc_queue_receives_all_transactions_from_channel(txs in any::<[Transaction; 2]>()) { let runtime = zebra_test::init_async(); let _guard = runtime.enter(); @@ -715,14 +858,17 @@ proptest! { // we use `expect_request_that` because we can't guarantee the state request order state - .expect_request_that(|request| matches!(request, zebra_state::ReadRequest::Transaction(_))) + .expect_request_that(|request| { + matches!(request, zebra_state::ReadRequest::Transaction(_)) + }) .await? .respond(response); } // each transaction will be retried for tx in txs.clone() { - let expected_request = mempool::Request::Queue(vec![mempool::Gossip::Tx(UnminedTx::from(tx))]); + let expected_request = + mempool::Request::Queue(vec![mempool::Gossip::Tx(UnminedTx::from(tx))]); let response = mempool::Response::Queued(vec![Ok(())]); mempool diff --git a/zebra-state/src/request.rs b/zebra-state/src/request.rs index f8bc64b9..55f10cce 100644 --- a/zebra-state/src/request.rs +++ b/zebra-state/src/request.rs @@ -1,6 +1,9 @@ //! State [`tower::Service`] request types. -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use zebra_chain::{ amount::NegativeAllowed, @@ -445,4 +448,9 @@ pub enum ReadRequest { /// Returned txids are in the order they appear in blocks, which ensures that they are topologically sorted /// (i.e. parent txids will appear before child txids). TransactionsByAddresses(Vec, block::Height, block::Height), + + /// Looks up the balance of a set of transparent addresses. + /// + /// Returns an [`Amount`] with the total balance of the set of addresses. + AddressBalance(HashSet), } diff --git a/zebra-state/src/response.rs b/zebra-state/src/response.rs index 6a0af14a..d2d25cd1 100644 --- a/zebra-state/src/response.rs +++ b/zebra-state/src/response.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use zebra_chain::{ + amount::{Amount, NonNegative}, block::{self, Block}, transaction::{Hash, Transaction}, transparent, @@ -57,4 +58,7 @@ pub enum ReadResponse { /// Response to [`ReadRequest::TransactionsByAddresses`] with the obtained transaction ids, /// in the order they appear in blocks. TransactionIds(Vec), + + /// Response to [`ReadRequest::AddressBalance`] with the total balance of the addresses. + AddressBalance(Amount), } diff --git a/zebra-state/src/service.rs b/zebra-state/src/service.rs index 0b790cee..3195a6e4 100644 --- a/zebra-state/src/service.rs +++ b/zebra-state/src/service.rs @@ -1014,6 +1014,27 @@ impl Service for ReadStateService { } .boxed() } + + // For the get_address_balance RPC. + ReadRequest::AddressBalance(addresses) => { + metrics::counter!( + "state.requests", + 1, + "service" => "read_state", + "type" => "address_balance", + ); + + let state = self.clone(); + + async move { + let balance = state.best_chain_receiver.with_watch_data(|best_chain| { + read::transparent_balance(best_chain, &state.db, addresses) + })?; + + Ok(ReadResponse::AddressBalance(balance)) + } + .boxed() + } } } }