From c1a2238f1a6258af765b03bbefd13da908fc15d1 Mon Sep 17 00:00:00 2001 From: Michael Sproul Date: Tue, 5 Nov 2019 15:46:52 +1100 Subject: [PATCH] Implement tree hash caching (#584) * Implement basic tree hash caching * Use spaces to indent top-level Cargo.toml * Optimize BLS tree hash by hashing bytes directly * Implement tree hash caching for validator registry * Persist BeaconState tree hash cache to disk * Address Paul's review comments --- Cargo.toml | 75 ++++---- beacon_node/store/src/impls/beacon_state.rs | 9 +- eth2/types/Cargo.toml | 1 + eth2/types/src/beacon_state.rs | 105 +++++++---- eth2/types/src/crosslink_committee.rs | 5 +- eth2/types/src/lib.rs | 1 + eth2/types/src/tree_hash_impls.rs | 129 ++++++++++++++ eth2/utils/bls/build.rs | 19 -- eth2/utils/bls/src/aggregate_signature.rs | 2 +- .../utils/bls/src/fake_aggregate_signature.rs | 2 +- eth2/utils/bls/src/fake_public_key.rs | 2 +- eth2/utils/bls/src/fake_signature.rs | 2 +- eth2/utils/bls/src/macros.rs | 21 ++- eth2/utils/bls/src/public_key.rs | 2 +- eth2/utils/bls/src/public_key_bytes.rs | 3 +- eth2/utils/bls/src/secret_key.rs | 2 +- eth2/utils/bls/src/signature.rs | 2 +- eth2/utils/bls/src/signature_bytes.rs | 8 +- eth2/utils/cached_tree_hash/Cargo.toml | 17 ++ eth2/utils/cached_tree_hash/src/cache.rs | 137 +++++++++++++++ eth2/utils/cached_tree_hash/src/impls.rs | 99 +++++++++++ eth2/utils/cached_tree_hash/src/lib.rs | 31 ++++ .../utils/cached_tree_hash/src/multi_cache.rs | 62 +++++++ eth2/utils/cached_tree_hash/src/test.rs | 147 ++++++++++++++++ eth2/utils/eth2_hashing/Cargo.toml | 9 +- eth2/utils/eth2_hashing/src/lib.rs | 38 ++++ eth2/utils/merkle_proof/src/lib.rs | 53 ++---- eth2/utils/tree_hash/Cargo.toml | 2 +- eth2/utils/tree_hash/benches/benches.rs | 56 ++++-- eth2/utils/tree_hash/src/impls.rs | 30 ---- eth2/utils/tree_hash/src/lib.rs | 5 +- eth2/utils/tree_hash/src/merkleize_padded.rs | 30 +--- eth2/utils/tree_hash_derive/src/lib.rs | 165 +++++++++++++++++- tests/ef_tests/Cargo.toml | 1 + tests/ef_tests/src/cases/ssz_generic.rs | 2 +- tests/ef_tests/src/cases/ssz_static.rs | 44 ++++- tests/ef_tests/src/handler.rs | 26 +++ tests/ef_tests/tests/tests.rs | 16 +- 38 files changed, 1112 insertions(+), 248 deletions(-) create mode 100644 eth2/types/src/tree_hash_impls.rs delete mode 100644 eth2/utils/bls/build.rs create mode 100644 eth2/utils/cached_tree_hash/Cargo.toml create mode 100644 eth2/utils/cached_tree_hash/src/cache.rs create mode 100644 eth2/utils/cached_tree_hash/src/impls.rs create mode 100644 eth2/utils/cached_tree_hash/src/lib.rs create mode 100644 eth2/utils/cached_tree_hash/src/multi_cache.rs create mode 100644 eth2/utils/cached_tree_hash/src/test.rs diff --git a/Cargo.toml b/Cargo.toml index 961615531..2edbced09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,44 +1,45 @@ [workspace] members = [ - "eth2/lmd_ghost", - "eth2/operation_pool", - "eth2/state_processing", - "eth2/types", - "eth2/utils/bls", - "eth2/utils/compare_fields", - "eth2/utils/compare_fields_derive", - "eth2/utils/eth2_config", - "eth2/utils/eth2_interop_keypairs", - "eth2/utils/logging", - "eth2/utils/eth2_hashing", - "eth2/utils/lighthouse_metrics", - "eth2/utils/lighthouse_bootstrap", - "eth2/utils/merkle_proof", - "eth2/utils/int_to_bytes", - "eth2/utils/serde_hex", - "eth2/utils/slot_clock", - "eth2/utils/ssz", - "eth2/utils/ssz_derive", - "eth2/utils/ssz_types", - "eth2/utils/swap_or_not_shuffle", - "eth2/utils/tree_hash", - "eth2/utils/tree_hash_derive", + "eth2/lmd_ghost", + "eth2/operation_pool", + "eth2/state_processing", + "eth2/types", + "eth2/utils/bls", + "eth2/utils/compare_fields", + "eth2/utils/compare_fields_derive", + "eth2/utils/eth2_config", + "eth2/utils/eth2_interop_keypairs", + "eth2/utils/logging", + "eth2/utils/eth2_hashing", + "eth2/utils/lighthouse_metrics", + "eth2/utils/lighthouse_bootstrap", + "eth2/utils/merkle_proof", + "eth2/utils/int_to_bytes", + "eth2/utils/serde_hex", + "eth2/utils/slot_clock", + "eth2/utils/ssz", + "eth2/utils/ssz_derive", + "eth2/utils/ssz_types", + "eth2/utils/swap_or_not_shuffle", + "eth2/utils/cached_tree_hash", + "eth2/utils/tree_hash", + "eth2/utils/tree_hash_derive", "eth2/utils/test_random_derive", - "beacon_node", - "beacon_node/store", - "beacon_node/client", - "beacon_node/rest_api", - "beacon_node/network", - "beacon_node/eth2-libp2p", + "beacon_node", + "beacon_node/store", + "beacon_node/client", + "beacon_node/rest_api", + "beacon_node/network", + "beacon_node/eth2-libp2p", "beacon_node/rpc", - "beacon_node/version", - "beacon_node/beacon_chain", - "beacon_node/websocket_server", - "tests/ef_tests", - "lcli", - "protos", - "validator_client", - "account_manager", + "beacon_node/version", + "beacon_node/beacon_chain", + "beacon_node/websocket_server", + "tests/ef_tests", + "lcli", + "protos", + "validator_client", + "account_manager", ] [patch] diff --git a/beacon_node/store/src/impls/beacon_state.rs b/beacon_node/store/src/impls/beacon_state.rs index 69e83cd63..2113d35bd 100644 --- a/beacon_node/store/src/impls/beacon_state.rs +++ b/beacon_node/store/src/impls/beacon_state.rs @@ -2,13 +2,15 @@ use crate::*; use ssz::{Decode, DecodeError, Encode}; use ssz_derive::{Decode, Encode}; use std::convert::TryInto; -use types::beacon_state::{CommitteeCache, CACHED_EPOCHS}; +use types::beacon_state::{BeaconTreeHashCache, CommitteeCache, CACHED_EPOCHS}; /// A container for storing `BeaconState` components. +// TODO: would be more space efficient with the caches stored separately and referenced by hash #[derive(Encode, Decode)] struct StorageContainer { state_bytes: Vec, committee_caches_bytes: Vec>, + tree_hash_cache_bytes: Vec, } impl StorageContainer { @@ -20,9 +22,12 @@ impl StorageContainer { committee_caches_bytes.push(cache.as_ssz_bytes()); } + let tree_hash_cache_bytes = state.tree_hash_cache.as_ssz_bytes(); + Self { state_bytes: state.as_ssz_bytes(), committee_caches_bytes, + tree_hash_cache_bytes, } } } @@ -43,6 +48,8 @@ impl TryInto> for StorageContainer { state.committee_caches[i] = CommitteeCache::from_ssz_bytes(bytes)?; } + state.tree_hash_cache = BeaconTreeHashCache::from_ssz_bytes(&self.tree_hash_cache_bytes)?; + Ok(state) } } diff --git a/eth2/types/Cargo.toml b/eth2/types/Cargo.toml index 9123ca6b3..e3138b26c 100644 --- a/eth2/types/Cargo.toml +++ b/eth2/types/Cargo.toml @@ -29,6 +29,7 @@ test_random_derive = { path = "../utils/test_random_derive" } tree_hash = "0.1.0" tree_hash_derive = "0.2" rand_xorshift = "0.2.0" +cached_tree_hash = { path = "../utils/cached_tree_hash" } [dev-dependencies] env_logger = "0.7.1" diff --git a/eth2/types/src/beacon_state.rs b/eth2/types/src/beacon_state.rs index f64deb38a..2aa805808 100644 --- a/eth2/types/src/beacon_state.rs +++ b/eth2/types/src/beacon_state.rs @@ -2,6 +2,7 @@ use self::committee_cache::get_active_validator_indices; use self::exit_cache::ExitCache; use crate::test_utils::TestRandom; use crate::*; +use cached_tree_hash::{CachedTreeHash, MultiTreeHashCache, TreeHashCache}; use compare_fields_derive::CompareFields; use eth2_hashing::hash; use int_to_bytes::{int_to_bytes32, int_to_bytes8}; @@ -12,7 +13,7 @@ use ssz_derive::{Decode, Encode}; use ssz_types::{typenum::Unsigned, BitVector, FixedVector}; use test_random_derive::TestRandom; use tree_hash::TreeHash; -use tree_hash_derive::TreeHash; +use tree_hash_derive::{CachedTreeHash, TreeHash}; pub use self::committee_cache::CommitteeCache; pub use eth_spec::*; @@ -57,6 +58,7 @@ pub enum Error { RelativeEpochError(RelativeEpochError), CommitteeCacheUninitialized(RelativeEpoch), SszTypesError(ssz_types::Error), + CachedTreeHashError(cached_tree_hash::Error), } /// Control whether an epoch-indexed field can be indexed at the next epoch or not. @@ -75,6 +77,26 @@ impl AllowNextEpoch { } } +#[derive(Debug, PartialEq, Clone, Default, Encode, Decode)] +pub struct BeaconTreeHashCache { + initialized: bool, + block_roots: TreeHashCache, + state_roots: TreeHashCache, + historical_roots: TreeHashCache, + validators: MultiTreeHashCache, + balances: TreeHashCache, + randao_mixes: TreeHashCache, + active_index_roots: TreeHashCache, + compact_committees_roots: TreeHashCache, + slashings: TreeHashCache, +} + +impl BeaconTreeHashCache { + pub fn is_initialized(&self) -> bool { + self.initialized + } +} + /// The state of the `BeaconChain` at some slot. /// /// Spec v0.8.0 @@ -88,9 +110,11 @@ impl AllowNextEpoch { Encode, Decode, TreeHash, + CachedTreeHash, CompareFields, )] #[serde(bound = "T: EthSpec")] +#[cached_tree_hash(type = "BeaconTreeHashCache")] pub struct BeaconState where T: EthSpec, @@ -103,9 +127,12 @@ where // History pub latest_block_header: BeaconBlockHeader, #[compare_fields(as_slice)] + #[cached_tree_hash(block_roots)] pub block_roots: FixedVector, #[compare_fields(as_slice)] + #[cached_tree_hash(state_roots)] pub state_roots: FixedVector, + #[cached_tree_hash(historical_roots)] pub historical_roots: VariableList, // Ethereum 1.0 chain data @@ -115,19 +142,25 @@ where // Registry #[compare_fields(as_slice)] + #[cached_tree_hash(validators)] pub validators: VariableList, #[compare_fields(as_slice)] + #[cached_tree_hash(balances)] pub balances: VariableList, // Shuffling pub start_shard: u64, + #[cached_tree_hash(randao_mixes)] pub randao_mixes: FixedVector, #[compare_fields(as_slice)] + #[cached_tree_hash(active_index_roots)] pub active_index_roots: FixedVector, #[compare_fields(as_slice)] + #[cached_tree_hash(compact_committees_roots)] pub compact_committees_roots: FixedVector, // Slashings + #[cached_tree_hash(slashings)] pub slashings: FixedVector, // Attestations @@ -164,6 +197,12 @@ where #[tree_hash(skip_hashing)] #[test_random(default)] pub exit_cache: ExitCache, + #[serde(skip_serializing, skip_deserializing)] + #[ssz(skip_serializing)] + #[ssz(skip_deserializing)] + #[tree_hash(skip_hashing)] + #[test_random(default)] + pub tree_hash_cache: BeaconTreeHashCache, } impl BeaconState { @@ -225,6 +264,7 @@ impl BeaconState { ], pubkey_cache: PubkeyCache::default(), exit_cache: ExitCache::default(), + tree_hash_cache: BeaconTreeHashCache::default(), } } @@ -825,7 +865,7 @@ impl BeaconState { self.build_committee_cache(RelativeEpoch::Current, spec)?; self.build_committee_cache(RelativeEpoch::Next, spec)?; self.update_pubkey_cache()?; - self.update_tree_hash_cache()?; + self.build_tree_hash_cache()?; self.exit_cache.build_from_registry(&self.validators, spec); Ok(()) @@ -936,41 +976,40 @@ impl BeaconState { self.pubkey_cache = PubkeyCache::default() } - /// Update the tree hash cache, building it for the first time if it is empty. - /// - /// Returns the `tree_hash_root` resulting from the update. This root can be considered the - /// canonical root of `self`. - /// - /// ## Note - /// - /// Cache not currently implemented, just performs a full tree hash. - pub fn update_tree_hash_cache(&mut self) -> Result { - // TODO(#440): re-enable cached tree hash - Ok(Hash256::from_slice(&self.tree_hash_root())) + /// Initialize but don't fill the tree hash cache, if it isn't already initialized. + pub fn initialize_tree_hash_cache(&mut self) { + if !self.tree_hash_cache.initialized { + self.tree_hash_cache = Self::new_tree_hash_cache(); + } } - /// Returns the tree hash root determined by the last execution of `self.update_tree_hash_cache(..)`. + /// Build and update the tree hash cache if it isn't already initialized. + pub fn build_tree_hash_cache(&mut self) -> Result<(), Error> { + self.update_tree_hash_cache().map(|_| ()) + } + + /// Build the tree hash cache, with blatant disregard for any existing cache. + pub fn force_build_tree_hash_cache(&mut self) -> Result<(), Error> { + self.tree_hash_cache.initialized = false; + self.build_tree_hash_cache() + } + + /// Compute the tree hash root of the state using the tree hash cache. /// - /// Note: does _not_ update the cache and may return an outdated root. - /// - /// Returns an error if the cache is not initialized or if an error is encountered during the - /// cache update. - /// - /// ## Note - /// - /// Cache not currently implemented, just performs a full tree hash. - pub fn cached_tree_hash_root(&self) -> Result { - // TODO(#440): re-enable cached tree hash - Ok(Hash256::from_slice(&self.tree_hash_root())) + /// Initialize the tree hash cache if it isn't already initialized. + pub fn update_tree_hash_cache(&mut self) -> Result { + self.initialize_tree_hash_cache(); + + let mut cache = std::mem::replace(&mut self.tree_hash_cache, <_>::default()); + let result = self.recalculate_tree_hash_root(&mut cache); + std::mem::replace(&mut self.tree_hash_cache, cache); + + Ok(result?) } /// Completely drops the tree hash cache, replacing it with a new, empty cache. - /// - /// ## Note - /// - /// Cache not currently implemented, is a no-op. pub fn drop_tree_hash_cache(&mut self) { - // TODO(#440): re-enable cached tree hash + self.tree_hash_cache = BeaconTreeHashCache::default(); } } @@ -985,3 +1024,9 @@ impl From for Error { Error::SszTypesError(e) } } + +impl From for Error { + fn from(e: cached_tree_hash::Error) -> Error { + Error::CachedTreeHashError(e) + } +} diff --git a/eth2/types/src/crosslink_committee.rs b/eth2/types/src/crosslink_committee.rs index 00c4bebc0..0f7a401ca 100644 --- a/eth2/types/src/crosslink_committee.rs +++ b/eth2/types/src/crosslink_committee.rs @@ -1,7 +1,6 @@ use crate::*; -use tree_hash_derive::TreeHash; -#[derive(Default, Clone, Debug, PartialEq, TreeHash)] +#[derive(Default, Clone, Debug, PartialEq)] pub struct CrosslinkCommittee<'a> { pub slot: Slot, pub shard: Shard, @@ -18,7 +17,7 @@ impl<'a> CrosslinkCommittee<'a> { } } -#[derive(Default, Clone, Debug, PartialEq, TreeHash)] +#[derive(Default, Clone, Debug, PartialEq)] pub struct OwnedCrosslinkCommittee { pub slot: Slot, pub shard: Shard, diff --git a/eth2/types/src/lib.rs b/eth2/types/src/lib.rs index 8f9c07b0d..fa23f9c1c 100644 --- a/eth2/types/src/lib.rs +++ b/eth2/types/src/lib.rs @@ -38,6 +38,7 @@ pub mod slot_epoch_macros; pub mod relative_epoch; pub mod slot_epoch; pub mod slot_height; +mod tree_hash_impls; pub mod validator; use ethereum_types::{H160, H256, U256}; diff --git a/eth2/types/src/tree_hash_impls.rs b/eth2/types/src/tree_hash_impls.rs new file mode 100644 index 000000000..2d652c475 --- /dev/null +++ b/eth2/types/src/tree_hash_impls.rs @@ -0,0 +1,129 @@ +//! This module contains custom implementations of `CachedTreeHash` for ETH2-specific types. +//! +//! It makes some assumptions about the layouts and update patterns of other structs in this +//! crate, and should be updated carefully whenever those structs are changed. +use crate::{Hash256, Validator}; +use cached_tree_hash::{int_log, CachedTreeHash, Error, TreeHashCache}; +use tree_hash::TreeHash; + +/// Number of struct fields on `Validator`. +const NUM_VALIDATOR_FIELDS: usize = 8; + +impl CachedTreeHash for Validator { + fn new_tree_hash_cache() -> TreeHashCache { + TreeHashCache::new(int_log(NUM_VALIDATOR_FIELDS)) + } + + /// Efficiently tree hash a `Validator`, assuming it was updated by a valid state transition. + /// + /// Specifically, we assume that the `pubkey` and `withdrawal_credentials` fields are constant. + fn recalculate_tree_hash_root(&self, cache: &mut TreeHashCache) -> Result { + // If the cache is empty, hash every field to fill it. + if cache.leaves().is_empty() { + return cache.recalculate_merkle_root(field_tree_hash_iter(self)); + } + + // Otherwise just check the fields which might have changed. + let dirty_indices = cache + .leaves() + .iter_mut() + .enumerate() + .flat_map(|(i, leaf)| { + // Fields pubkey and withdrawal_credentials are constant + if i == 0 || i == 1 { + None + } else { + let new_tree_hash = field_tree_hash_by_index(self, i); + if leaf.as_bytes() != &new_tree_hash[..] { + leaf.assign_from_slice(&new_tree_hash); + Some(i) + } else { + None + } + } + }) + .collect(); + + cache.update_merkle_root(dirty_indices) + } +} + +/// Get the tree hash root of a validator field by its position/index in the struct. +fn field_tree_hash_by_index(v: &Validator, field_idx: usize) -> Vec { + match field_idx { + 0 => v.pubkey.tree_hash_root(), + 1 => v.withdrawal_credentials.tree_hash_root(), + 2 => v.effective_balance.tree_hash_root(), + 3 => v.slashed.tree_hash_root(), + 4 => v.activation_eligibility_epoch.tree_hash_root(), + 5 => v.activation_epoch.tree_hash_root(), + 6 => v.exit_epoch.tree_hash_root(), + 7 => v.withdrawable_epoch.tree_hash_root(), + _ => panic!( + "Validator type only has {} fields, {} out of bounds", + NUM_VALIDATOR_FIELDS, field_idx + ), + } +} + +/// Iterator over the tree hash roots of `Validator` fields. +fn field_tree_hash_iter<'a>( + v: &'a Validator, +) -> impl Iterator + ExactSizeIterator + 'a { + (0..NUM_VALIDATOR_FIELDS) + .map(move |i| field_tree_hash_by_index(v, i)) + .map(|tree_hash_root| { + let mut res = [0; 32]; + res.copy_from_slice(&tree_hash_root[0..32]); + res + }) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::test_utils::TestRandom; + use crate::Epoch; + use rand::SeedableRng; + use rand_xorshift::XorShiftRng; + + fn test_validator_tree_hash(v: &Validator) { + let mut cache = Validator::new_tree_hash_cache(); + // With a fresh cache + assert_eq!( + &v.tree_hash_root()[..], + v.recalculate_tree_hash_root(&mut cache).unwrap().as_bytes(), + "{:?}", + v + ); + // With a completely up-to-date cache + assert_eq!( + &v.tree_hash_root()[..], + v.recalculate_tree_hash_root(&mut cache).unwrap().as_bytes(), + "{:?}", + v + ); + } + + #[test] + fn default_validator() { + test_validator_tree_hash(&Validator::default()); + } + + #[test] + fn zeroed_validator() { + let mut v = Validator::default(); + v.activation_eligibility_epoch = Epoch::from(0u64); + v.activation_epoch = Epoch::from(0u64); + test_validator_tree_hash(&v); + } + + #[test] + fn random_validators() { + let mut rng = XorShiftRng::from_seed([0xf1; 16]); + let num_validators = 1000; + (0..num_validators) + .map(|_| Validator::random_for_test(&mut rng)) + .for_each(|v| test_validator_tree_hash(&v)); + } +} diff --git a/eth2/utils/bls/build.rs b/eth2/utils/bls/build.rs deleted file mode 100644 index 7f08a1ed5..000000000 --- a/eth2/utils/bls/build.rs +++ /dev/null @@ -1,19 +0,0 @@ -// This build script is symlinked from each project that requires BLS's "fake crypto", -// so that the `fake_crypto` feature of every sub-crate can be turned on by running -// with FAKE_CRYPTO=1 from the top-level workspace. -// At some point in the future it might be possible to do: -// $ cargo test --all --release --features fake_crypto -// but at the present time this doesn't work. -// Related: https://github.com/rust-lang/cargo/issues/5364 -fn main() { - if let Ok(fake_crypto) = std::env::var("FAKE_CRYPTO") { - if fake_crypto == "1" { - println!("cargo:rustc-cfg=feature=\"fake_crypto\""); - println!("cargo:rerun-if-env-changed=FAKE_CRYPTO"); - println!( - "cargo:warning=[{}]: Compiled with fake BLS cryptography. DO NOT USE, TESTING ONLY", - std::env::var("CARGO_PKG_NAME").unwrap() - ); - } - } -} diff --git a/eth2/utils/bls/src/aggregate_signature.rs b/eth2/utils/bls/src/aggregate_signature.rs index e80c1b100..5a081c943 100644 --- a/eth2/utils/bls/src/aggregate_signature.rs +++ b/eth2/utils/bls/src/aggregate_signature.rs @@ -155,7 +155,7 @@ impl_ssz!( "AggregateSignature" ); -impl_tree_hash!(AggregateSignature, U96); +impl_tree_hash!(AggregateSignature, BLS_AGG_SIG_BYTE_SIZE); impl Serialize for AggregateSignature { /// Serde serialization is compliant the Ethereum YAML test format. diff --git a/eth2/utils/bls/src/fake_aggregate_signature.rs b/eth2/utils/bls/src/fake_aggregate_signature.rs index 7911bb57a..52495a76e 100644 --- a/eth2/utils/bls/src/fake_aggregate_signature.rs +++ b/eth2/utils/bls/src/fake_aggregate_signature.rs @@ -93,7 +93,7 @@ impl_ssz!( "FakeAggregateSignature" ); -impl_tree_hash!(FakeAggregateSignature, U96); +impl_tree_hash!(FakeAggregateSignature, BLS_AGG_SIG_BYTE_SIZE); impl Serialize for FakeAggregateSignature { fn serialize(&self, serializer: S) -> Result diff --git a/eth2/utils/bls/src/fake_public_key.rs b/eth2/utils/bls/src/fake_public_key.rs index 82b1c707f..f9440d86d 100644 --- a/eth2/utils/bls/src/fake_public_key.rs +++ b/eth2/utils/bls/src/fake_public_key.rs @@ -102,7 +102,7 @@ impl default::Default for FakePublicKey { impl_ssz!(FakePublicKey, BLS_PUBLIC_KEY_BYTE_SIZE, "FakePublicKey"); -impl_tree_hash!(FakePublicKey, U48); +impl_tree_hash!(FakePublicKey, BLS_PUBLIC_KEY_BYTE_SIZE); impl Serialize for FakePublicKey { fn serialize(&self, serializer: S) -> Result diff --git a/eth2/utils/bls/src/fake_signature.rs b/eth2/utils/bls/src/fake_signature.rs index 6e34a518c..3ece5e87b 100644 --- a/eth2/utils/bls/src/fake_signature.rs +++ b/eth2/utils/bls/src/fake_signature.rs @@ -91,7 +91,7 @@ impl FakeSignature { impl_ssz!(FakeSignature, BLS_SIG_BYTE_SIZE, "FakeSignature"); -impl_tree_hash!(FakeSignature, U96); +impl_tree_hash!(FakeSignature, BLS_SIG_BYTE_SIZE); impl Serialize for FakeSignature { fn serialize(&self, serializer: S) -> Result diff --git a/eth2/utils/bls/src/macros.rs b/eth2/utils/bls/src/macros.rs index e8bd3dd04..4acf185f0 100644 --- a/eth2/utils/bls/src/macros.rs +++ b/eth2/utils/bls/src/macros.rs @@ -42,7 +42,7 @@ macro_rules! impl_ssz { } macro_rules! impl_tree_hash { - ($type: ty, $byte_size: ident) => { + ($type: ty, $byte_size: expr) => { impl tree_hash::TreeHash for $type { fn tree_hash_type() -> tree_hash::TreeHashType { tree_hash::TreeHashType::Vector @@ -57,16 +57,19 @@ macro_rules! impl_tree_hash { } fn tree_hash_root(&self) -> Vec { - let vector: ssz_types::FixedVector = - ssz_types::FixedVector::from(self.as_ssz_bytes()); - vector.tree_hash_root() + // We could use the tree hash implementation for `FixedVec`, + // but benchmarks have show that to be at least 15% slower because of the + // unnecessary copying and allocation (one Vec per byte) + let values_per_chunk = tree_hash::BYTES_PER_CHUNK; + let minimum_chunk_count = ($byte_size + values_per_chunk - 1) / values_per_chunk; + tree_hash::merkle_root(&self.as_ssz_bytes(), minimum_chunk_count) } } }; } macro_rules! bytes_struct { - ($name: ident, $type: ty, $byte_size: expr, $small_name: expr, $ssz_type_size: ident, + ($name: ident, $type: ty, $byte_size: expr, $small_name: expr, $type_str: expr, $byte_size_str: expr) => { #[doc = "Stores `"] #[doc = $byte_size_str] @@ -82,9 +85,9 @@ macro_rules! bytes_struct { #[derive(Clone)] pub struct $name([u8; $byte_size]); }; - ($name: ident, $type: ty, $byte_size: expr, $small_name: expr, $ssz_type_size: ident) => { - bytes_struct!($name, $type, $byte_size, $small_name, $ssz_type_size, stringify!($type), - stringify!($byte_size)); + ($name: ident, $type: ty, $byte_size: expr, $small_name: expr) => { + bytes_struct!($name, $type, $byte_size, $small_name, stringify!($type), + stringify!($byte_size)); impl $name { pub fn from_bytes(bytes: &[u8]) -> Result { @@ -144,7 +147,7 @@ macro_rules! bytes_struct { impl_ssz!($name, $byte_size, "$type"); - impl_tree_hash!($name, $ssz_type_size); + impl_tree_hash!($name, $byte_size); impl serde::ser::Serialize for $name { /// Serde serialization is compliant the Ethereum YAML test format. diff --git a/eth2/utils/bls/src/public_key.rs b/eth2/utils/bls/src/public_key.rs index 4b5abb58e..87204fae1 100644 --- a/eth2/utils/bls/src/public_key.rs +++ b/eth2/utils/bls/src/public_key.rs @@ -94,7 +94,7 @@ impl default::Default for PublicKey { impl_ssz!(PublicKey, BLS_PUBLIC_KEY_BYTE_SIZE, "PublicKey"); -impl_tree_hash!(PublicKey, U48); +impl_tree_hash!(PublicKey, BLS_PUBLIC_KEY_BYTE_SIZE); impl Serialize for PublicKey { fn serialize(&self, serializer: S) -> Result diff --git a/eth2/utils/bls/src/public_key_bytes.rs b/eth2/utils/bls/src/public_key_bytes.rs index afdbcb270..528ef8254 100644 --- a/eth2/utils/bls/src/public_key_bytes.rs +++ b/eth2/utils/bls/src/public_key_bytes.rs @@ -6,8 +6,7 @@ bytes_struct!( PublicKeyBytes, PublicKey, BLS_PUBLIC_KEY_BYTE_SIZE, - "public key", - U48 + "public key" ); #[cfg(test)] diff --git a/eth2/utils/bls/src/secret_key.rs b/eth2/utils/bls/src/secret_key.rs index d9ada7333..6e39cace3 100644 --- a/eth2/utils/bls/src/secret_key.rs +++ b/eth2/utils/bls/src/secret_key.rs @@ -49,7 +49,7 @@ impl SecretKey { impl_ssz!(SecretKey, BLS_SECRET_KEY_BYTE_SIZE, "SecretKey"); -impl_tree_hash!(SecretKey, U48); +impl_tree_hash!(SecretKey, BLS_SECRET_KEY_BYTE_SIZE); impl Serialize for SecretKey { fn serialize(&self, serializer: S) -> Result diff --git a/eth2/utils/bls/src/signature.rs b/eth2/utils/bls/src/signature.rs index 7a2bc6051..64f306b30 100644 --- a/eth2/utils/bls/src/signature.rs +++ b/eth2/utils/bls/src/signature.rs @@ -108,7 +108,7 @@ impl Signature { impl_ssz!(Signature, BLS_SIG_BYTE_SIZE, "Signature"); -impl_tree_hash!(Signature, U96); +impl_tree_hash!(Signature, BLS_SIG_BYTE_SIZE); impl Serialize for Signature { /// Serde serialization is compliant the Ethereum YAML test format. diff --git a/eth2/utils/bls/src/signature_bytes.rs b/eth2/utils/bls/src/signature_bytes.rs index b89c0f0d1..bfec269b0 100644 --- a/eth2/utils/bls/src/signature_bytes.rs +++ b/eth2/utils/bls/src/signature_bytes.rs @@ -2,13 +2,7 @@ use ssz::{Decode, DecodeError, Encode}; use super::{Signature, BLS_SIG_BYTE_SIZE}; -bytes_struct!( - SignatureBytes, - Signature, - BLS_SIG_BYTE_SIZE, - "signature", - U96 -); +bytes_struct!(SignatureBytes, Signature, BLS_SIG_BYTE_SIZE, "signature"); #[cfg(test)] mod tests { diff --git a/eth2/utils/cached_tree_hash/Cargo.toml b/eth2/utils/cached_tree_hash/Cargo.toml new file mode 100644 index 000000000..5ed95c78d --- /dev/null +++ b/eth2/utils/cached_tree_hash/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "cached_tree_hash" +version = "0.1.0" +authors = ["Michael Sproul "] +edition = "2018" + +[dependencies] +ethereum-types = "0.8" +eth2_ssz_types = { path = "../ssz_types" } +eth2_hashing = "0.1" +eth2_ssz_derive = "0.1.0" +eth2_ssz = "0.1.2" +tree_hash = "0.1" + +[dev-dependencies] +quickcheck = "0.9" +quickcheck_macros = "0.8" diff --git a/eth2/utils/cached_tree_hash/src/cache.rs b/eth2/utils/cached_tree_hash/src/cache.rs new file mode 100644 index 000000000..4a5d650fb --- /dev/null +++ b/eth2/utils/cached_tree_hash/src/cache.rs @@ -0,0 +1,137 @@ +use crate::{Error, Hash256}; +use eth2_hashing::{hash_concat, ZERO_HASHES}; +use ssz_derive::{Decode, Encode}; +use tree_hash::BYTES_PER_CHUNK; + +/// Sparse Merkle tree suitable for tree hashing vectors and lists. +#[derive(Debug, PartialEq, Clone, Default, Encode, Decode)] +pub struct TreeHashCache { + /// Depth is such that the tree has a capacity for 2^depth leaves + depth: usize, + /// Sparse layers. + /// + /// The leaves are contained in `self.layers[self.depth]`, and each other layer `i` + /// contains the parents of the nodes in layer `i + 1`. + layers: Vec>, +} + +impl TreeHashCache { + /// Create a new cache with the given `depth`, but no actual content. + pub fn new(depth: usize) -> Self { + TreeHashCache { + depth, + layers: vec![vec![]; depth + 1], + } + } + + /// Compute the updated Merkle root for the given `leaves`. + pub fn recalculate_merkle_root( + &mut self, + leaves: impl Iterator + ExactSizeIterator, + ) -> Result { + let dirty_indices = self.update_leaves(leaves)?; + self.update_merkle_root(dirty_indices) + } + + /// Phase 1 of the algorithm: compute the indices of all dirty leaves. + pub fn update_leaves( + &mut self, + mut leaves: impl Iterator + ExactSizeIterator, + ) -> Result, Error> { + let new_leaf_count = leaves.len(); + + if new_leaf_count < self.leaves().len() { + return Err(Error::CannotShrink); + } else if new_leaf_count > 2usize.pow(self.depth as u32) { + return Err(Error::TooManyLeaves); + } + + // Update the existing leaves + let mut dirty = self + .leaves() + .iter_mut() + .enumerate() + .zip(&mut leaves) + .flat_map(|((i, leaf), new_leaf)| { + if leaf.as_bytes() != new_leaf { + leaf.assign_from_slice(&new_leaf); + Some(i) + } else { + None + } + }) + .collect::>(); + + // Push the rest of the new leaves (if any) + dirty.extend(self.leaves().len()..new_leaf_count); + self.leaves() + .extend(leaves.map(|l| Hash256::from_slice(&l))); + + Ok(dirty) + } + + /// Phase 2: propagate changes upwards from the leaves of the tree, and compute the root. + /// + /// Returns an error if `dirty_indices` is inconsistent with the cache. + pub fn update_merkle_root(&mut self, mut dirty_indices: Vec) -> Result { + if dirty_indices.is_empty() { + return Ok(self.root()); + } + + let mut depth = self.depth; + + while depth > 0 { + let new_dirty_indices = lift_dirty(&dirty_indices); + + for &idx in &new_dirty_indices { + let left_idx = 2 * idx; + let right_idx = left_idx + 1; + + let left = self.layers[depth][left_idx]; + let right = self.layers[depth] + .get(right_idx) + .copied() + .unwrap_or_else(|| Hash256::from_slice(&ZERO_HASHES[self.depth - depth])); + + let new_hash = hash_concat(left.as_bytes(), right.as_bytes()); + + match self.layers[depth - 1].get_mut(idx) { + Some(hash) => { + hash.assign_from_slice(&new_hash); + } + None => { + // Parent layer should already contain nodes for all non-dirty indices + if idx != self.layers[depth - 1].len() { + return Err(Error::CacheInconsistent); + } + self.layers[depth - 1].push(Hash256::from_slice(&new_hash)); + } + } + } + + dirty_indices = new_dirty_indices; + depth -= 1; + } + + Ok(self.root()) + } + + /// Get the root of this cache, without doing any updates/computation. + pub fn root(&self) -> Hash256 { + self.layers[0] + .get(0) + .copied() + .unwrap_or_else(|| Hash256::from_slice(&ZERO_HASHES[self.depth])) + } + + pub fn leaves(&mut self) -> &mut Vec { + &mut self.layers[self.depth] + } +} + +/// Compute the dirty indices for one layer up. +fn lift_dirty(dirty_indices: &[usize]) -> Vec { + let mut new_dirty = dirty_indices.iter().map(|i| *i / 2).collect::>(); + new_dirty.dedup(); + new_dirty +} diff --git a/eth2/utils/cached_tree_hash/src/impls.rs b/eth2/utils/cached_tree_hash/src/impls.rs new file mode 100644 index 000000000..c5bc18120 --- /dev/null +++ b/eth2/utils/cached_tree_hash/src/impls.rs @@ -0,0 +1,99 @@ +use crate::{CachedTreeHash, Error, Hash256, TreeHashCache}; +use ssz_types::{typenum::Unsigned, FixedVector, VariableList}; +use std::mem::size_of; +use tree_hash::{mix_in_length, BYTES_PER_CHUNK}; + +/// Compute ceil(log(n)) +/// +/// Smallest number of bits d so that n <= 2^d +pub fn int_log(n: usize) -> usize { + match n.checked_next_power_of_two() { + Some(x) => x.trailing_zeros() as usize, + None => 8 * std::mem::size_of::(), + } +} + +pub fn hash256_iter<'a>( + values: &'a [Hash256], +) -> impl Iterator + ExactSizeIterator + 'a { + values.iter().copied().map(Hash256::to_fixed_bytes) +} + +pub fn u64_iter<'a>( + values: &'a [u64], +) -> impl Iterator + ExactSizeIterator + 'a { + let type_size = size_of::(); + let vals_per_chunk = BYTES_PER_CHUNK / type_size; + values.chunks(vals_per_chunk).map(move |xs| { + xs.iter().map(|x| x.to_le_bytes()).enumerate().fold( + [0; BYTES_PER_CHUNK], + |mut chunk, (i, x_bytes)| { + chunk[i * type_size..(i + 1) * type_size].copy_from_slice(&x_bytes); + chunk + }, + ) + }) +} + +impl CachedTreeHash for FixedVector { + fn new_tree_hash_cache() -> TreeHashCache { + TreeHashCache::new(int_log(N::to_usize())) + } + + fn recalculate_tree_hash_root(&self, cache: &mut TreeHashCache) -> Result { + cache.recalculate_merkle_root(hash256_iter(&self)) + } +} + +impl CachedTreeHash for FixedVector { + fn new_tree_hash_cache() -> TreeHashCache { + let vals_per_chunk = BYTES_PER_CHUNK / size_of::(); + TreeHashCache::new(int_log(N::to_usize() / vals_per_chunk)) + } + + fn recalculate_tree_hash_root(&self, cache: &mut TreeHashCache) -> Result { + cache.recalculate_merkle_root(u64_iter(&self)) + } +} + +impl CachedTreeHash for VariableList { + fn new_tree_hash_cache() -> TreeHashCache { + TreeHashCache::new(int_log(N::to_usize())) + } + + fn recalculate_tree_hash_root(&self, cache: &mut TreeHashCache) -> Result { + Ok(Hash256::from_slice(&mix_in_length( + cache + .recalculate_merkle_root(hash256_iter(&self))? + .as_bytes(), + self.len(), + ))) + } +} + +impl CachedTreeHash for VariableList { + fn new_tree_hash_cache() -> TreeHashCache { + let vals_per_chunk = BYTES_PER_CHUNK / size_of::(); + TreeHashCache::new(int_log(N::to_usize() / vals_per_chunk)) + } + + fn recalculate_tree_hash_root(&self, cache: &mut TreeHashCache) -> Result { + Ok(Hash256::from_slice(&mix_in_length( + cache.recalculate_merkle_root(u64_iter(&self))?.as_bytes(), + self.len(), + ))) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_int_log() { + for i in 0..63 { + assert_eq!(int_log(2usize.pow(i)), i as usize); + } + assert_eq!(int_log(10), 4); + } +} diff --git a/eth2/utils/cached_tree_hash/src/lib.rs b/eth2/utils/cached_tree_hash/src/lib.rs new file mode 100644 index 000000000..cc47ab21f --- /dev/null +++ b/eth2/utils/cached_tree_hash/src/lib.rs @@ -0,0 +1,31 @@ +mod cache; +mod impls; +mod multi_cache; +#[cfg(test)] +mod test; + +pub use crate::cache::TreeHashCache; +pub use crate::impls::int_log; +pub use crate::multi_cache::MultiTreeHashCache; +use ethereum_types::H256 as Hash256; +use tree_hash::TreeHash; + +#[derive(Debug, PartialEq)] +pub enum Error { + /// Attempting to provide more than 2^depth leaves to a Merkle tree is disallowed. + TooManyLeaves, + /// Shrinking a Merkle tree cache by providing it with less leaves than it currently has is + /// disallowed (for simplicity). + CannotShrink, + /// Cache is inconsistent with the list of dirty indices provided. + CacheInconsistent, +} + +/// Trait for types which can make use of a cache to accelerate calculation of their tree hash root. +pub trait CachedTreeHash: TreeHash { + /// Create a new cache appropriate for use with values of this type. + fn new_tree_hash_cache() -> Cache; + + /// Update the cache and use it to compute the tree hash root for `self`. + fn recalculate_tree_hash_root(&self, cache: &mut Cache) -> Result; +} diff --git a/eth2/utils/cached_tree_hash/src/multi_cache.rs b/eth2/utils/cached_tree_hash/src/multi_cache.rs new file mode 100644 index 000000000..df2f6a011 --- /dev/null +++ b/eth2/utils/cached_tree_hash/src/multi_cache.rs @@ -0,0 +1,62 @@ +use crate::{int_log, CachedTreeHash, Error, Hash256, TreeHashCache}; +use ssz_derive::{Decode, Encode}; +use ssz_types::{typenum::Unsigned, VariableList}; +use tree_hash::mix_in_length; + +/// Multi-level tree hash cache. +/// +/// Suitable for lists/vectors/containers holding values which themselves have caches. +/// +/// Note: this cache could be made composable by replacing the hardcoded `Vec` with +/// `Vec`, allowing arbitrary nesting, but for now we stick to 2-level nesting because that's all +/// we need. +#[derive(Debug, PartialEq, Clone, Default, Encode, Decode)] +pub struct MultiTreeHashCache { + list_cache: TreeHashCache, + value_caches: Vec, +} + +impl CachedTreeHash for VariableList +where + T: CachedTreeHash, + N: Unsigned, +{ + fn new_tree_hash_cache() -> MultiTreeHashCache { + MultiTreeHashCache { + list_cache: TreeHashCache::new(int_log(N::to_usize())), + value_caches: vec![], + } + } + + fn recalculate_tree_hash_root(&self, cache: &mut MultiTreeHashCache) -> Result { + if self.len() < cache.value_caches.len() { + return Err(Error::CannotShrink); + } + + // Resize the value caches to the size of the list. + cache + .value_caches + .resize(self.len(), T::new_tree_hash_cache()); + + // Update all individual value caches. + self.iter() + .zip(cache.value_caches.iter_mut()) + .try_for_each(|(value, cache)| value.recalculate_tree_hash_root(cache).map(|_| ()))?; + + // Pipe the value roots into the list cache, then mix in the length. + // Note: it's possible to avoid this 2nd iteration (or an allocation) by using + // `itertools::process_results`, but it requires removing the `ExactSizeIterator` + // bound from `recalculate_merkle_root`, and only saves about 5% in benchmarks. + let list_root = cache.list_cache.recalculate_merkle_root( + cache + .value_caches + .iter() + .map(|value_cache| value_cache.root().to_fixed_bytes()), + )?; + + Ok(Hash256::from_slice(&mix_in_length( + list_root.as_bytes(), + self.len(), + ))) + } +} diff --git a/eth2/utils/cached_tree_hash/src/test.rs b/eth2/utils/cached_tree_hash/src/test.rs new file mode 100644 index 000000000..68173fd6a --- /dev/null +++ b/eth2/utils/cached_tree_hash/src/test.rs @@ -0,0 +1,147 @@ +use crate::impls::hash256_iter; +use crate::{CachedTreeHash, Error, Hash256, TreeHashCache}; +use eth2_hashing::ZERO_HASHES; +use quickcheck_macros::quickcheck; +use ssz_types::{ + typenum::{Unsigned, U16, U255, U256, U257}, + FixedVector, VariableList, +}; +use tree_hash::TreeHash; + +fn int_hashes(start: u64, end: u64) -> Vec { + (start..end).map(Hash256::from_low_u64_le).collect() +} + +type List16 = VariableList; +type Vector16 = FixedVector; +type Vector16u64 = FixedVector; + +#[test] +fn max_leaves() { + let depth = 4; + let max_len = 2u64.pow(depth as u32); + let mut cache = TreeHashCache::new(depth); + assert!(cache + .recalculate_merkle_root(hash256_iter(&int_hashes(0, max_len - 1))) + .is_ok()); + assert!(cache + .recalculate_merkle_root(hash256_iter(&int_hashes(0, max_len))) + .is_ok()); + assert_eq!( + cache.recalculate_merkle_root(hash256_iter(&int_hashes(0, max_len + 1))), + Err(Error::TooManyLeaves) + ); + assert_eq!( + cache.recalculate_merkle_root(hash256_iter(&int_hashes(0, max_len * 2))), + Err(Error::TooManyLeaves) + ); +} + +#[test] +fn cannot_shrink() { + let init_len = 12; + let list1 = List16::new(int_hashes(0, init_len)).unwrap(); + let list2 = List16::new(int_hashes(0, init_len - 1)).unwrap(); + + let mut cache = List16::new_tree_hash_cache(); + assert!(list1.recalculate_tree_hash_root(&mut cache).is_ok()); + assert_eq!( + list2.recalculate_tree_hash_root(&mut cache), + Err(Error::CannotShrink) + ); +} + +#[test] +fn empty_leaves() { + let depth = 20; + let mut cache = TreeHashCache::new(depth); + assert_eq!( + cache + .recalculate_merkle_root(vec![].into_iter()) + .unwrap() + .as_bytes(), + &ZERO_HASHES[depth][..] + ); +} + +#[test] +fn fixed_vector_hash256() { + let len = 16; + let vec = Vector16::new(int_hashes(0, len)).unwrap(); + + let mut cache = Vector16::new_tree_hash_cache(); + + assert_eq!( + Hash256::from_slice(&vec.tree_hash_root()), + vec.recalculate_tree_hash_root(&mut cache).unwrap() + ); +} + +#[test] +fn fixed_vector_u64() { + let len = 16; + let vec = Vector16u64::new((0..len).collect()).unwrap(); + + let mut cache = Vector16u64::new_tree_hash_cache(); + + assert_eq!( + Hash256::from_slice(&vec.tree_hash_root()), + vec.recalculate_tree_hash_root(&mut cache).unwrap() + ); +} + +#[test] +fn variable_list_hash256() { + let len = 13; + let list = List16::new(int_hashes(0, len)).unwrap(); + + let mut cache = List16::new_tree_hash_cache(); + + assert_eq!( + Hash256::from_slice(&list.tree_hash_root()), + list.recalculate_tree_hash_root(&mut cache).unwrap() + ); +} + +#[quickcheck] +fn quickcheck_variable_list_h256_256(leaves_and_skips: Vec<(u64, bool)>) -> bool { + variable_list_h256_test::(leaves_and_skips) +} + +#[quickcheck] +fn quickcheck_variable_list_h256_255(leaves_and_skips: Vec<(u64, bool)>) -> bool { + variable_list_h256_test::(leaves_and_skips) +} + +#[quickcheck] +fn quickcheck_variable_list_h256_257(leaves_and_skips: Vec<(u64, bool)>) -> bool { + variable_list_h256_test::(leaves_and_skips) +} + +fn variable_list_h256_test(leaves_and_skips: Vec<(u64, bool)>) -> bool { + let leaves: Vec<_> = leaves_and_skips + .iter() + .map(|(l, _)| Hash256::from_low_u64_be(*l)) + .take(Len::to_usize()) + .collect(); + + let mut list: VariableList; + let mut cache = VariableList::::new_tree_hash_cache(); + + for (end, (_, update_cache)) in leaves_and_skips.into_iter().enumerate() { + list = VariableList::new(leaves[..end].to_vec()).unwrap(); + + if update_cache { + if list + .recalculate_tree_hash_root(&mut cache) + .unwrap() + .as_bytes() + != &list.tree_hash_root()[..] + { + return false; + } + } + } + + true +} diff --git a/eth2/utils/eth2_hashing/Cargo.toml b/eth2/utils/eth2_hashing/Cargo.toml index af48d0d4e..3047a7a4d 100644 --- a/eth2/utils/eth2_hashing/Cargo.toml +++ b/eth2/utils/eth2_hashing/Cargo.toml @@ -1,11 +1,14 @@ [package] name = "eth2_hashing" -version = "0.1.0" +version = "0.1.1" authors = ["Paul Hauner "] edition = "2018" license = "Apache-2.0" description = "Hashing primitives used in Ethereum 2.0" +[dependencies] +lazy_static = { version = "1.4.0", optional = true } + [target.'cfg(not(target_arch = "wasm32"))'.dependencies] ring = "0.16.9" @@ -17,3 +20,7 @@ rustc-hex = "2.0.1" [target.'cfg(target_arch = "wasm32")'.dev-dependencies] wasm-bindgen-test = "0.3.2" + +[features] +default = ["zero_hash_cache"] +zero_hash_cache = ["lazy_static"] diff --git a/eth2/utils/eth2_hashing/src/lib.rs b/eth2/utils/eth2_hashing/src/lib.rs index 94d072d8d..555c5bbe3 100644 --- a/eth2/utils/eth2_hashing/src/lib.rs +++ b/eth2/utils/eth2_hashing/src/lib.rs @@ -10,6 +10,9 @@ use ring::digest::{digest, SHA256}; #[cfg(target_arch = "wasm32")] use sha2::{Digest, Sha256}; +#[cfg(feature = "zero_hash_cache")] +use lazy_static::lazy_static; + /// Returns the digest of `input`. /// /// Uses `ring::digest::SHA256`. @@ -23,6 +26,31 @@ pub fn hash(input: &[u8]) -> Vec { h } +/// Compute the hash of two slices concatenated. +pub fn hash_concat(h1: &[u8], h2: &[u8]) -> Vec { + let mut vec1 = h1.to_vec(); + vec1.extend_from_slice(h2); + hash(&vec1) +} + +/// The max index that can be used with `ZERO_HASHES`. +#[cfg(feature = "zero_hash_cache")] +pub const ZERO_HASHES_MAX_INDEX: usize = 48; + +#[cfg(feature = "zero_hash_cache")] +lazy_static! { + /// Cached zero hashes where `ZERO_HASHES[i]` is the hash of a Merkle tree with 2^i zero leaves. + pub static ref ZERO_HASHES: Vec> = { + let mut hashes = vec![vec![0; 32]; ZERO_HASHES_MAX_INDEX + 1]; + + for i in 0..ZERO_HASHES_MAX_INDEX { + hashes[i + 1] = hash_concat(&hashes[i], &hashes[i]); + } + + hashes + }; +} + #[cfg(test)] mod tests { use super::*; @@ -41,4 +69,14 @@ mod tests { let expected: Vec = expected_hex.from_hex().unwrap(); assert_eq!(expected, output); } + + #[cfg(feature = "zero_hash_cache")] + mod zero_hash { + use super::*; + + #[test] + fn zero_hash_zero() { + assert_eq!(ZERO_HASHES[0], vec![0; 32]); + } + } } diff --git a/eth2/utils/merkle_proof/src/lib.rs b/eth2/utils/merkle_proof/src/lib.rs index 785072eb4..356c66835 100644 --- a/eth2/utils/merkle_proof/src/lib.rs +++ b/eth2/utils/merkle_proof/src/lib.rs @@ -1,24 +1,11 @@ -#[macro_use] -extern crate lazy_static; - -use eth2_hashing::hash; +use eth2_hashing::{hash, hash_concat, ZERO_HASHES}; use ethereum_types::H256; +use lazy_static::lazy_static; const MAX_TREE_DEPTH: usize = 32; const EMPTY_SLICE: &[H256] = &[]; lazy_static! { - /// Cached zero hashes where `ZERO_HASHES[i]` is the hash of a Merkle tree with 2^i zero leaves. - static ref ZERO_HASHES: Vec = { - let mut hashes = vec![H256::from([0; 32]); MAX_TREE_DEPTH + 1]; - - for i in 0..MAX_TREE_DEPTH { - hashes[i + 1] = hash_concat(hashes[i], hashes[i]); - } - - hashes - }; - /// Zero nodes to act as "synthetic" left and right subtrees of other zero nodes. static ref ZERO_NODES: Vec = { (0..=MAX_TREE_DEPTH).map(MerkleTree::Zero).collect() @@ -78,7 +65,10 @@ impl MerkleTree { let left_subtree = MerkleTree::create(left_leaves, depth - 1); let right_subtree = MerkleTree::create(right_leaves, depth - 1); - let hash = hash_concat(left_subtree.hash(), right_subtree.hash()); + let hash = H256::from_slice(&hash_concat( + left_subtree.hash().as_bytes(), + right_subtree.hash().as_bytes(), + )); Node(hash, Box::new(left_subtree), Box::new(right_subtree)) } @@ -146,7 +136,7 @@ impl MerkleTree { match *self { MerkleTree::Leaf(h) => h, MerkleTree::Node(h, _, _) => h, - MerkleTree::Zero(depth) => ZERO_HASHES[depth], + MerkleTree::Zero(depth) => H256::from_slice(&ZERO_HASHES[depth]), } } @@ -228,8 +218,7 @@ fn merkle_root_from_branch(leaf: H256, branch: &[H256], depth: usize, index: usi for (i, leaf) in branch.iter().enumerate().take(depth) { let ith_bit = (index >> i) & 0x01; if ith_bit == 1 { - let input = concat(leaf.as_bytes().to_vec(), merkle_root); - merkle_root = hash(&input); + merkle_root = hash_concat(leaf.as_bytes(), &merkle_root); } else { let mut input = merkle_root; input.extend_from_slice(leaf.as_bytes()); @@ -240,20 +229,6 @@ fn merkle_root_from_branch(leaf: H256, branch: &[H256], depth: usize, index: usi H256::from_slice(&merkle_root) } -/// Concatenate two vectors. -fn concat(mut vec1: Vec, mut vec2: Vec) -> Vec { - vec1.append(&mut vec2); - vec1 -} - -/// Compute the hash of two other hashes concatenated. -fn hash_concat(h1: H256, h2: H256) -> H256 { - H256::from_slice(&hash(&concat( - h1.as_bytes().to_vec(), - h2.as_bytes().to_vec(), - ))) -} - #[cfg(test)] mod tests { use super::*; @@ -318,10 +293,10 @@ mod tests { let leaf_b10 = H256::from([0xCC; 32]); let leaf_b11 = H256::from([0xDD; 32]); - let node_b0x = hash_concat(leaf_b00, leaf_b01); - let node_b1x = hash_concat(leaf_b10, leaf_b11); + let node_b0x = H256::from_slice(&hash_concat(leaf_b00.as_bytes(), leaf_b01.as_bytes())); + let node_b1x = H256::from_slice(&hash_concat(leaf_b10.as_bytes(), leaf_b11.as_bytes())); - let root = hash_concat(node_b0x, node_b1x); + let root = H256::from_slice(&hash_concat(node_b0x.as_bytes(), node_b1x.as_bytes())); let tree = MerkleTree::create(&[leaf_b00, leaf_b01, leaf_b10, leaf_b11], 2); assert_eq!(tree.hash(), root); @@ -335,10 +310,10 @@ mod tests { let leaf_b10 = H256::from([0xCC; 32]); let leaf_b11 = H256::from([0xDD; 32]); - let node_b0x = hash_concat(leaf_b00, leaf_b01); - let node_b1x = hash_concat(leaf_b10, leaf_b11); + let node_b0x = H256::from_slice(&hash_concat(leaf_b00.as_bytes(), leaf_b01.as_bytes())); + let node_b1x = H256::from_slice(&hash_concat(leaf_b10.as_bytes(), leaf_b11.as_bytes())); - let root = hash_concat(node_b0x, node_b1x); + let root = H256::from_slice(&hash_concat(node_b0x.as_bytes(), node_b1x.as_bytes())); // Run some proofs assert!(verify_merkle_proof( diff --git a/eth2/utils/tree_hash/Cargo.toml b/eth2/utils/tree_hash/Cargo.toml index e416a3f8e..7d48b1707 100644 --- a/eth2/utils/tree_hash/Cargo.toml +++ b/eth2/utils/tree_hash/Cargo.toml @@ -15,8 +15,8 @@ criterion = "0.3.0" rand = "0.7.2" tree_hash_derive = "0.2" types = { path = "../../types" } +lazy_static = "1.4.0" [dependencies] ethereum-types = "0.8.0" eth2_hashing = "0.1.0" -lazy_static = "1.4.0" diff --git a/eth2/utils/tree_hash/benches/benches.rs b/eth2/utils/tree_hash/benches/benches.rs index bad6f3a39..d734a7342 100644 --- a/eth2/utils/tree_hash/benches/benches.rs +++ b/eth2/utils/tree_hash/benches/benches.rs @@ -1,8 +1,6 @@ -#[macro_use] -extern crate lazy_static; - use criterion::Criterion; use criterion::{black_box, criterion_group, criterion_main, Benchmark}; +use lazy_static::lazy_static; use types::test_utils::{generate_deterministic_keypairs, TestingBeaconStateBuilder}; use types::{BeaconState, EthSpec, Keypair, MainnetEthSpec, MinimalEthSpec}; @@ -27,25 +25,61 @@ fn build_state(validator_count: usize) -> BeaconState { state } +// Note: `state.canonical_root()` uses whatever `tree_hash` that the `types` crate +// uses, which is not necessarily this crate. If you want to ensure that types is +// using this local version of `tree_hash`, ensure you add a workspace-level +// [dependency +// patch](https://doc.rust-lang.org/cargo/reference/manifest.html#the-patch-section). fn bench_suite(c: &mut Criterion, spec_desc: &str, validator_count: usize) { - let state = build_state::(validator_count); + let state1 = build_state::(validator_count); + let state2 = state1.clone(); + let mut state3 = state1.clone(); + state3.build_tree_hash_cache().unwrap(); c.bench( - &format!("{}/{}_validators", spec_desc, validator_count), + &format!("{}/{}_validators/no_cache", spec_desc, validator_count), Benchmark::new("genesis_state", move |b| { b.iter_batched_ref( - || state.clone(), - // Note: `state.canonical_root()` uses whatever `tree_hash` that the `types` crate - // uses, which is not necessarily this crate. If you want to ensure that types is - // using this local version of `tree_hash`, ensure you add a workspace-level - // [dependency - // patch](https://doc.rust-lang.org/cargo/reference/manifest.html#the-patch-section). + || state1.clone(), |state| black_box(state.canonical_root()), criterion::BatchSize::SmallInput, ) }) .sample_size(10), ); + + c.bench( + &format!("{}/{}_validators/empty_cache", spec_desc, validator_count), + Benchmark::new("genesis_state", move |b| { + b.iter_batched_ref( + || state2.clone(), + |state| { + assert!(!state.tree_hash_cache.is_initialized()); + black_box(state.update_tree_hash_cache().unwrap()) + }, + criterion::BatchSize::SmallInput, + ) + }) + .sample_size(10), + ); + + c.bench( + &format!( + "{}/{}_validators/up_to_date_cache", + spec_desc, validator_count + ), + Benchmark::new("genesis_state", move |b| { + b.iter_batched_ref( + || state3.clone(), + |state| { + assert!(state.tree_hash_cache.is_initialized()); + black_box(state.update_tree_hash_cache().unwrap()) + }, + criterion::BatchSize::SmallInput, + ) + }) + .sample_size(10), + ); } fn all_benches(c: &mut Criterion) { diff --git a/eth2/utils/tree_hash/src/impls.rs b/eth2/utils/tree_hash/src/impls.rs index 9f09f50ce..25630cf97 100644 --- a/eth2/utils/tree_hash/src/impls.rs +++ b/eth2/utils/tree_hash/src/impls.rs @@ -131,36 +131,6 @@ impl TreeHash for H256 { } } -// TODO: this implementation always panics, it only exists to allow us to compile whilst -// refactoring tree hash. Should be removed. -macro_rules! impl_for_list { - ($type: ty) => { - impl TreeHash for $type - where - T: TreeHash, - { - fn tree_hash_type() -> TreeHashType { - unimplemented!("TreeHash is not implemented for Vec or slice") - } - - fn tree_hash_packed_encoding(&self) -> Vec { - unimplemented!("TreeHash is not implemented for Vec or slice") - } - - fn tree_hash_packing_factor() -> usize { - unimplemented!("TreeHash is not implemented for Vec or slice") - } - - fn tree_hash_root(&self) -> Vec { - unimplemented!("TreeHash is not implemented for Vec or slice") - } - } - }; -} - -impl_for_list!(Vec); -impl_for_list!(&[T]); - /// Returns `int` as little-endian bytes with a length of 32. fn int_to_bytes32(int: u64) -> Vec { let mut vec = int.to_le_bytes().to_vec(); diff --git a/eth2/utils/tree_hash/src/lib.rs b/eth2/utils/tree_hash/src/lib.rs index 72a77f03e..0b3be72c4 100644 --- a/eth2/utils/tree_hash/src/lib.rs +++ b/eth2/utils/tree_hash/src/lib.rs @@ -1,6 +1,3 @@ -#[macro_use] -extern crate lazy_static; - pub mod impls; mod merkleize_padded; mod merkleize_standard; @@ -27,7 +24,7 @@ pub fn mix_in_length(root: &[u8], length: usize) -> Vec { let mut length_bytes = length.to_le_bytes().to_vec(); length_bytes.resize(BYTES_PER_CHUNK, 0); - merkleize_padded::hash_concat(root, &length_bytes) + eth2_hashing::hash_concat(root, &length_bytes) } #[derive(Debug, PartialEq, Clone)] diff --git a/eth2/utils/tree_hash/src/merkleize_padded.rs b/eth2/utils/tree_hash/src/merkleize_padded.rs index bfec55e1c..832c0bbd8 100644 --- a/eth2/utils/tree_hash/src/merkleize_padded.rs +++ b/eth2/utils/tree_hash/src/merkleize_padded.rs @@ -1,25 +1,10 @@ use super::BYTES_PER_CHUNK; -use eth2_hashing::hash; +use eth2_hashing::{hash, hash_concat, ZERO_HASHES, ZERO_HASHES_MAX_INDEX}; /// The size of the cache that stores padding nodes for a given height. /// /// Currently, we panic if we encounter a tree with a height larger than `MAX_TREE_DEPTH`. -/// -/// It is set to 48 as we expect it to be sufficiently high that we won't exceed it. -pub const MAX_TREE_DEPTH: usize = 48; - -lazy_static! { - /// Cached zero hashes where `ZERO_HASHES[i]` is the hash of a Merkle tree with 2^i zero leaves. - static ref ZERO_HASHES: Vec> = { - let mut hashes = vec![vec![0; 32]; MAX_TREE_DEPTH + 1]; - - for i in 0..MAX_TREE_DEPTH { - hashes[i + 1] = hash_concat(&hashes[i], &hashes[i]); - } - - hashes - }; -} +pub const MAX_TREE_DEPTH: usize = ZERO_HASHES_MAX_INDEX; /// Merkleize `bytes` and return the root, optionally padding the tree out to `min_leaves` number of /// leaves. @@ -236,17 +221,6 @@ fn get_zero_hash(height: usize) -> &'static [u8] { } } -/// Concatenate two vectors. -fn concat(mut vec1: Vec, mut vec2: Vec) -> Vec { - vec1.append(&mut vec2); - vec1 -} - -/// Compute the hash of two other hashes concatenated. -pub fn hash_concat(h1: &[u8], h2: &[u8]) -> Vec { - hash(&concat(h1.to_vec(), h2.to_vec())) -} - /// Returns the next even number following `n`. If `n` is even, `n` is returned. fn next_even_number(n: usize) -> usize { n + n % 2 diff --git a/eth2/utils/tree_hash_derive/src/lib.rs b/eth2/utils/tree_hash_derive/src/lib.rs index e396b4fda..bd869dee5 100644 --- a/eth2/utils/tree_hash_derive/src/lib.rs +++ b/eth2/utils/tree_hash_derive/src/lib.rs @@ -3,14 +3,25 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, DeriveInput}; +use std::collections::HashMap; +use syn::{parse_macro_input, Attribute, DeriveInput, Meta}; -/// Returns a Vec of `syn::Ident` for each named field in the struct, whilst filtering out fields +/// Return a Vec of `syn::Ident` for each named field in the struct, whilst filtering out fields /// that should not be hashed. /// /// # Panics /// Any unnamed struct field (like in a tuple struct) will raise a panic at compile time. -fn get_hashable_named_field_idents<'a>(struct_data: &'a syn::DataStruct) -> Vec<&'a syn::Ident> { +fn get_hashable_fields<'a>(struct_data: &'a syn::DataStruct) -> Vec<&'a syn::Ident> { + get_hashable_fields_and_their_caches(struct_data) + .into_iter() + .map(|(ident, _, _)| ident) + .collect() +} + +/// Return a Vec of the hashable fields of a struct, and each field's type and optional cache field. +fn get_hashable_fields_and_their_caches<'a>( + struct_data: &'a syn::DataStruct, +) -> Vec<(&'a syn::Ident, syn::Type, Option)> { struct_data .fields .iter() @@ -18,15 +29,77 @@ fn get_hashable_named_field_idents<'a>(struct_data: &'a syn::DataStruct) -> Vec< if should_skip_hashing(&f) { None } else { - Some(match &f.ident { - Some(ref ident) => ident, - _ => panic!("tree_hash_derive only supports named struct fields."), - }) + let ident = f + .ident + .as_ref() + .expect("tree_hash_derive only supports named struct fields"); + let opt_cache_field = get_cache_field_for(&f); + Some((ident, f.ty.clone(), opt_cache_field)) } }) .collect() } +/// Parse the cached_tree_hash attribute for a field. +/// +/// Extract the cache field name from `#[cached_tree_hash(cache_field_name)]` +/// +/// Return `Some(cache_field_name)` if the field has a cached tree hash attribute, +/// or `None` otherwise. +fn get_cache_field_for<'a>(field: &'a syn::Field) -> Option { + use syn::{MetaList, NestedMeta}; + + let parsed_attrs = cached_tree_hash_attr_metas(&field.attrs); + if let [Meta::List(MetaList { nested, .. })] = &parsed_attrs[..] { + nested.iter().find_map(|x| match x { + NestedMeta::Meta(Meta::Word(cache_field_ident)) => Some(cache_field_ident.clone()), + _ => None, + }) + } else { + None + } +} + +/// Process the `cached_tree_hash` attributes from a list of attributes into structured `Meta`s. +fn cached_tree_hash_attr_metas(attrs: &[Attribute]) -> Vec { + attrs + .iter() + .filter(|attr| attr.path.is_ident("cached_tree_hash")) + .flat_map(|attr| attr.parse_meta()) + .collect() +} + +/// Parse the top-level cached_tree_hash struct attribute. +/// +/// Return the type from `#[cached_tree_hash(type = "T")]`. +/// +/// **Panics** if the attribute is missing or the type is malformed. +fn parse_cached_tree_hash_struct_attrs(attrs: &[Attribute]) -> syn::Type { + use syn::{Lit, MetaList, MetaNameValue, NestedMeta}; + + let parsed_attrs = cached_tree_hash_attr_metas(attrs); + if let [Meta::List(MetaList { nested, .. })] = &parsed_attrs[..] { + let eqns = nested + .iter() + .flat_map(|x| match x { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + ident, + lit: Lit::Str(lit_str), + .. + })) => Some((ident.to_string(), lit_str.clone())), + _ => None, + }) + .collect::>(); + + eqns["type"] + .clone() + .parse() + .expect("valid type required for cache") + } else { + panic!("missing attribute `#[cached_tree_hash(type = ...)` on struct"); + } +} + /// Returns true if some field has an attribute declaring it should not be hashed. /// /// The field attribute is: `#[tree_hash(skip_hashing)]` @@ -51,7 +124,7 @@ pub fn tree_hash_derive(input: TokenStream) -> TokenStream { _ => panic!("tree_hash_derive only supports structs."), }; - let idents = get_hashable_named_field_idents(&struct_data); + let idents = get_hashable_fields(&struct_data); let output = quote! { impl #impl_generics tree_hash::TreeHash for #name #ty_generics #where_clause { @@ -112,6 +185,82 @@ pub fn tree_hash_signed_root_derive(input: TokenStream) -> TokenStream { output.into() } +/// Derive the `CachedTreeHash` trait for a type. +/// +/// Requires two attributes: +/// * `#[cached_tree_hash(type = "T")]` on the struct, declaring +/// that the type `T` should be used as the tree hash cache. +/// * `#[cached_tree_hash(f)]` on each struct field that makes use +/// of the cache, which declares that the sub-cache for that field +/// can be found in the field `cache.f` of the struct's cache. +#[proc_macro_derive(CachedTreeHash, attributes(cached_tree_hash))] +pub fn cached_tree_hash_derive(input: TokenStream) -> TokenStream { + let item = parse_macro_input!(input as DeriveInput); + + let name = &item.ident; + + let cache_type = parse_cached_tree_hash_struct_attrs(&item.attrs); + + let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl(); + + let struct_data = match &item.data { + syn::Data::Struct(s) => s, + _ => panic!("tree_hash_derive only supports structs."), + }; + + let fields = get_hashable_fields_and_their_caches(&struct_data); + let caching_field_ty = fields + .iter() + .filter(|(_, _, cache_field)| cache_field.is_some()) + .map(|(_, ty, _)| ty); + let caching_field_cache_field = fields + .iter() + .flat_map(|(_, _, cache_field)| cache_field.as_ref()); + + let tree_hash_root_expr = fields + .iter() + .map(|(field, _, caching_field)| match caching_field { + None => quote! { + self.#field.tree_hash_root() + }, + Some(caching_field) => quote! { + self.#field + .recalculate_tree_hash_root(&mut cache.#caching_field)? + .as_bytes() + .to_vec() + }, + }); + + let output = quote! { + impl #impl_generics cached_tree_hash::CachedTreeHash<#cache_type> for #name #ty_generics #where_clause { + fn new_tree_hash_cache() -> #cache_type { + // Call new cache for each sub type + #cache_type { + initialized: true, + #( + #caching_field_cache_field: <#caching_field_ty>::new_tree_hash_cache() + ),* + } + } + + fn recalculate_tree_hash_root( + &self, + cache: &mut #cache_type) + -> Result + { + let mut leaves = vec![]; + + #( + leaves.append(&mut #tree_hash_root_expr); + )* + + Ok(Hash256::from_slice(&tree_hash::merkle_root(&leaves, 0))) + } + } + }; + output.into() +} + fn get_signed_root_named_field_idents(struct_data: &syn::DataStruct) -> Vec<&syn::Ident> { struct_data .fields diff --git a/tests/ef_tests/Cargo.toml b/tests/ef_tests/Cargo.toml index b0d281b8d..e893ea8e2 100644 --- a/tests/ef_tests/Cargo.toml +++ b/tests/ef_tests/Cargo.toml @@ -23,6 +23,7 @@ eth2_ssz = "0.1.2" eth2_ssz_derive = "0.1.0" tree_hash = "0.1.0" tree_hash_derive = "0.2" +cached_tree_hash = { path = "../../eth2/utils/cached_tree_hash" } state_processing = { path = "../../eth2/state_processing" } swap_or_not_shuffle = { path = "../../eth2/utils/swap_or_not_shuffle" } types = { path = "../../eth2/types" } diff --git a/tests/ef_tests/src/cases/ssz_generic.rs b/tests/ef_tests/src/cases/ssz_generic.rs index 442dd6e09..7aa198bea 100644 --- a/tests/ef_tests/src/cases/ssz_generic.rs +++ b/tests/ef_tests/src/cases/ssz_generic.rs @@ -218,7 +218,7 @@ fn ssz_generic_test(path: &Path) -> Result<(), Error> { check_serialization(&value, &serialized)?; if let Some(ref meta) = meta { - check_tree_hash(&meta.root, value.tree_hash_root())?; + check_tree_hash(&meta.root, &value.tree_hash_root())?; } } // Invalid diff --git a/tests/ef_tests/src/cases/ssz_static.rs b/tests/ef_tests/src/cases/ssz_static.rs index 62f285d58..e4c216f76 100644 --- a/tests/ef_tests/src/cases/ssz_static.rs +++ b/tests/ef_tests/src/cases/ssz_static.rs @@ -2,8 +2,10 @@ use super::*; use crate::case_result::compare_result; use crate::cases::common::SszStaticType; use crate::decode::yaml_decode_file; +use cached_tree_hash::CachedTreeHash; use serde_derive::Deserialize; use std::fs; +use std::marker::PhantomData; use tree_hash::SignedRoot; use types::Hash256; @@ -27,6 +29,14 @@ pub struct SszStaticSR { value: T, } +#[derive(Debug, Clone)] +pub struct SszStaticTHC { + roots: SszStaticRoots, + serialized: Vec, + value: T, + _phantom: PhantomData, +} + fn load_from_dir(path: &Path) -> Result<(SszStaticRoots, Vec, T), Error> { let roots = yaml_decode_file(&path.join("roots.yaml"))?; let serialized = fs::read(&path.join("serialized.ssz")).expect("serialized.ssz exists"); @@ -55,6 +65,17 @@ impl LoadCase for SszStaticSR { } } +impl, C: Debug + Sync> LoadCase for SszStaticTHC { + fn load_from_dir(path: &Path) -> Result { + load_from_dir(path).map(|(roots, serialized, value)| Self { + roots, + serialized, + value, + _phantom: PhantomData, + }) + } +} + pub fn check_serialization(value: &T, serialized: &[u8]) -> Result<(), Error> { // Check serialization let serialized_result = value.as_ssz_bytes(); @@ -68,18 +89,18 @@ pub fn check_serialization(value: &T, serialized: &[u8]) -> Re Ok(()) } -pub fn check_tree_hash(expected_str: &str, actual_root: Vec) -> Result<(), Error> { +pub fn check_tree_hash(expected_str: &str, actual_root: &[u8]) -> Result<(), Error> { let expected_root = hex::decode(&expected_str[2..]) .map_err(|e| Error::FailedToParseTest(format!("{:?}", e)))?; let expected_root = Hash256::from_slice(&expected_root); - let tree_hash_root = Hash256::from_slice(&actual_root); + let tree_hash_root = Hash256::from_slice(actual_root); compare_result::(&Ok(tree_hash_root), &Some(expected_root)) } impl Case for SszStatic { fn result(&self, _case_index: usize) -> Result<(), Error> { check_serialization(&self.value, &self.serialized)?; - check_tree_hash(&self.roots.root, self.value.tree_hash_root())?; + check_tree_hash(&self.roots.root, &self.value.tree_hash_root())?; Ok(()) } } @@ -87,15 +108,28 @@ impl Case for SszStatic { impl Case for SszStaticSR { fn result(&self, _case_index: usize) -> Result<(), Error> { check_serialization(&self.value, &self.serialized)?; - check_tree_hash(&self.roots.root, self.value.tree_hash_root())?; + check_tree_hash(&self.roots.root, &self.value.tree_hash_root())?; check_tree_hash( &self .roots .signing_root .as_ref() .expect("signed root exists"), - self.value.signed_root(), + &self.value.signed_root(), )?; Ok(()) } } + +impl, C: Debug + Sync> Case for SszStaticTHC { + fn result(&self, _case_index: usize) -> Result<(), Error> { + check_serialization(&self.value, &self.serialized)?; + check_tree_hash(&self.roots.root, &self.value.tree_hash_root())?; + + let mut cache = T::new_tree_hash_cache(); + let cached_tree_hash_root = self.value.recalculate_tree_hash_root(&mut cache).unwrap(); + check_tree_hash(&self.roots.root, cached_tree_hash_root.as_bytes())?; + + Ok(()) + } +} diff --git a/tests/ef_tests/src/handler.rs b/tests/ef_tests/src/handler.rs index e5d175e11..df2b6603b 100644 --- a/tests/ef_tests/src/handler.rs +++ b/tests/ef_tests/src/handler.rs @@ -1,6 +1,8 @@ use crate::cases::{self, Case, Cases, EpochTransition, LoadCase, Operation}; use crate::type_name; use crate::type_name::TypeName; +use cached_tree_hash::CachedTreeHash; +use std::fmt::Debug; use std::fs; use std::marker::PhantomData; use std::path::PathBuf; @@ -93,6 +95,9 @@ pub struct SszStaticHandler(PhantomData<(T, E)>); /// Handler for SSZ types that do implement `SignedRoot`. pub struct SszStaticSRHandler(PhantomData<(T, E)>); +/// Handler for SSZ types that implement `CachedTreeHash`. +pub struct SszStaticTHCHandler(PhantomData<(T, C, E)>); + impl Handler for SszStaticHandler where T: cases::SszStaticType + TypeName, @@ -133,6 +138,27 @@ where } } +impl Handler for SszStaticTHCHandler +where + T: cases::SszStaticType + CachedTreeHash + TypeName, + C: Debug + Sync, + E: TypeName, +{ + type Case = cases::SszStaticTHC; + + fn config_name() -> &'static str { + E::name() + } + + fn runner_name() -> &'static str { + "ssz_static" + } + + fn handler_name() -> String { + T::name().into() + } +} + pub struct ShufflingHandler(PhantomData); impl Handler for ShufflingHandler { diff --git a/tests/ef_tests/tests/tests.rs b/tests/ef_tests/tests/tests.rs index 43cb79a8d..a9b7c22ac 100644 --- a/tests/ef_tests/tests/tests.rs +++ b/tests/ef_tests/tests/tests.rs @@ -99,7 +99,7 @@ macro_rules! ssz_static_test { ($test_name:ident, $typ:ident$(<$generics:tt>)?, SR) => { ssz_static_test!($test_name, SszStaticSRHandler, $typ$(<$generics>)?); }; - // Non-signed root + // Non-signed root, non-tree hash caching ($test_name:ident, $typ:ident$(<$generics:tt>)?) => { ssz_static_test!($test_name, SszStaticHandler, $typ$(<$generics>)?); }; @@ -122,11 +122,11 @@ macro_rules! ssz_static_test { ); }; // Base case - ($test_name:ident, $handler:ident, { $(($typ:ty, $spec:ident)),+ }) => { + ($test_name:ident, $handler:ident, { $(($($typ:ty),+)),+ }) => { #[test] fn $test_name() { $( - $handler::<$typ, $spec>::run(); + $handler::<$($typ),+>::run(); )+ } }; @@ -134,7 +134,7 @@ macro_rules! ssz_static_test { #[cfg(feature = "fake_crypto")] mod ssz_static { - use ef_tests::{Handler, SszStaticHandler, SszStaticSRHandler}; + use ef_tests::{Handler, SszStaticHandler, SszStaticSRHandler, SszStaticTHCHandler}; use types::*; ssz_static_test!(attestation, Attestation<_>, SR); @@ -147,7 +147,13 @@ mod ssz_static { ssz_static_test!(beacon_block, BeaconBlock<_>, SR); ssz_static_test!(beacon_block_body, BeaconBlockBody<_>); ssz_static_test!(beacon_block_header, BeaconBlockHeader, SR); - ssz_static_test!(beacon_state, BeaconState<_>); + ssz_static_test!( + beacon_state, + SszStaticTHCHandler, { + (BeaconState, BeaconTreeHashCache, MinimalEthSpec), + (BeaconState, BeaconTreeHashCache, MainnetEthSpec) + } + ); ssz_static_test!(checkpoint, Checkpoint); ssz_static_test!(compact_committee, CompactCommittee<_>); ssz_static_test!(crosslink, Crosslink);