From ebe47a5b341b1bc710671f4a9dc22ddcf996a017 Mon Sep 17 00:00:00 2001 From: Paul Hauner Date: Mon, 1 Apr 2019 14:56:32 +1100 Subject: [PATCH] Add `Store` and `db_encode_derive`. Implementation is not complete, but what is here works. --- Cargo.toml | 3 + beacon_node/db/Cargo.toml | 2 - beacon_node/db2/Cargo.toml | 16 + beacon_node/db2/src/disk_db.rs | 199 ++++++++++++ beacon_node/db2/src/lib.rs | 151 +++++++++ beacon_node/db2/src/memory_db.rs | 236 ++++++++++++++ .../db2/src/stores/beacon_block_store.rs | 246 ++++++++++++++ .../db2/src/stores/beacon_state_store.rs | 62 ++++ beacon_node/db2/src/stores/macros.rs | 103 ++++++ beacon_node/db2/src/stores/mod.rs | 25 ++ beacon_node/db2/src/stores/pow_chain_store.rs | 68 ++++ beacon_node/db2/src/stores/validator_store.rs | 215 ++++++++++++ beacon_node/db2/src/traits.rs | 38 +++ beacon_node/db_encode/Cargo.toml | 9 + beacon_node/db_encode/src/lib.rs | 59 ++++ beacon_node/db_encode_derive/Cargo.toml | 13 + beacon_node/db_encode_derive/src/lib.rs | 305 ++++++++++++++++++ 17 files changed, 1748 insertions(+), 2 deletions(-) create mode 100644 beacon_node/db2/Cargo.toml create mode 100644 beacon_node/db2/src/disk_db.rs create mode 100644 beacon_node/db2/src/lib.rs create mode 100644 beacon_node/db2/src/memory_db.rs create mode 100644 beacon_node/db2/src/stores/beacon_block_store.rs create mode 100644 beacon_node/db2/src/stores/beacon_state_store.rs create mode 100644 beacon_node/db2/src/stores/macros.rs create mode 100644 beacon_node/db2/src/stores/mod.rs create mode 100644 beacon_node/db2/src/stores/pow_chain_store.rs create mode 100644 beacon_node/db2/src/stores/validator_store.rs create mode 100644 beacon_node/db2/src/traits.rs create mode 100644 beacon_node/db_encode/Cargo.toml create mode 100644 beacon_node/db_encode/src/lib.rs create mode 100644 beacon_node/db_encode_derive/Cargo.toml create mode 100644 beacon_node/db_encode_derive/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 3ae62248b..008e83bae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,9 @@ members = [ "eth2/utils/test_random_derive", "beacon_node", "beacon_node/db", + "beacon_node/db2", + "beacon_node/db_encode", + "beacon_node/db_encode_derive", "beacon_node/client", "beacon_node/network", "beacon_node/eth2-libp2p", diff --git a/beacon_node/db/Cargo.toml b/beacon_node/db/Cargo.toml index 122aaa34d..ffb3585b9 100644 --- a/beacon_node/db/Cargo.toml +++ b/beacon_node/db/Cargo.toml @@ -9,5 +9,3 @@ blake2-rfc = "0.2.18" bls = { path = "../../eth2/utils/bls" } bytes = "0.4.10" rocksdb = "0.10.1" -ssz = { path = "../../eth2/utils/ssz" } -types = { path = "../../eth2/types" } diff --git a/beacon_node/db2/Cargo.toml b/beacon_node/db2/Cargo.toml new file mode 100644 index 000000000..8a5dbad5e --- /dev/null +++ b/beacon_node/db2/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "db2" +version = "0.1.0" +authors = ["Paul Hauner "] +edition = "2018" + +[dependencies] +blake2-rfc = "0.2.18" +bls = { path = "../../eth2/utils/bls" } +bytes = "0.4.10" +db_encode = { path = "../db_encode" } +db_encode_derive = { path = "../db_encode_derive" } +rocksdb = "0.10.1" +ssz = { path = "../../eth2/utils/ssz" } +ssz_derive = { path = "../../eth2/utils/ssz_derive" } +types = { path = "../../eth2/types" } diff --git a/beacon_node/db2/src/disk_db.rs b/beacon_node/db2/src/disk_db.rs new file mode 100644 index 000000000..f05320f7f --- /dev/null +++ b/beacon_node/db2/src/disk_db.rs @@ -0,0 +1,199 @@ +extern crate rocksdb; + +use super::rocksdb::Error as RocksError; +use super::rocksdb::{Options, DB}; +use 
super::stores::COLUMNS;
+use super::{ClientDB, DBError, DBValue};
+use std::fs;
+use std::path::Path;
+
+/// An on-disk database which implements the ClientDB trait.
+///
+/// This implementation uses RocksDB with default options.
+pub struct DiskDB {
+    db: DB,
+}
+
+impl DiskDB {
+    /// Open the RocksDB database, optionally supplying columns if required.
+    ///
+    /// The RocksDB database will be contained in a directory titled
+    /// "database" in the supplied path.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the database cannot be created.
+    pub fn open(path: &Path, columns: Option<&[&str]>) -> Self {
+        // Rocks options.
+        let mut options = Options::default();
+        options.create_if_missing(true);
+
+        // Ensure the path exists.
+        fs::create_dir_all(&path).unwrap_or_else(|_| panic!("Unable to create {:?}", &path));
+        let db_path = path.join("database");
+
+        let columns = columns.unwrap_or(&COLUMNS);
+
+        if db_path.exists() {
+            Self {
+                db: DB::open_cf(&options, db_path, &COLUMNS)
+                    .expect("Unable to open local database"),
+            }
+        } else {
+            let mut db = Self {
+                db: DB::open(&options, db_path).expect("Unable to open local database"),
+            };
+
+            for cf in columns {
+                db.create_col(cf).unwrap();
+            }
+
+            db
+        }
+    }
+
+    /// Create a RocksDB column family. Corresponds to the
+    /// `create_cf()` function on the RocksDB API.
+    #[allow(dead_code)]
+    fn create_col(&mut self, col: &str) -> Result<(), DBError> {
+        match self.db.create_cf(col, &Options::default()) {
+            Err(e) => Err(e.into()),
+            Ok(_) => Ok(()),
+        }
+    }
+}
+
+impl From<RocksError> for DBError {
+    fn from(e: RocksError) -> Self {
+        Self {
+            message: e.to_string(),
+        }
+    }
+}
+
+impl ClientDB for DiskDB {
+    /// Get the value for some key on some column.
+    ///
+    /// Corresponds to the `get_cf()` method on the RocksDB API.
+    /// Will attempt to get the `ColumnFamily` and return an Err
+    /// if it fails.
+    fn get(&self, col: &str, key: &[u8]) -> Result<Option<DBValue>, DBError> {
+        match self.db.cf_handle(col) {
+            None => Err(DBError {
+                message: "Unknown column".to_string(),
+            }),
+            Some(handle) => match self.db.get_cf(handle, key)? {
+                None => Ok(None),
+                Some(db_vec) => Ok(Some(DBValue::from(&*db_vec))),
+            },
+        }
+    }
+
+    /// Set some value for some key on some column.
+    ///
+    /// Corresponds to the `put_cf()` method on the RocksDB API.
+    /// Will attempt to get the `ColumnFamily` and return an Err
+    /// if it fails.
+    fn put(&self, col: &str, key: &[u8], val: &[u8]) -> Result<(), DBError> {
+        match self.db.cf_handle(col) {
+            None => Err(DBError {
+                message: "Unknown column".to_string(),
+            }),
+            Some(handle) => self.db.put_cf(handle, key, val).map_err(|e| e.into()),
+        }
+    }
+
+    /// Return true if some key exists in some column.
+    fn exists(&self, col: &str, key: &[u8]) -> Result<bool, DBError> {
+        /*
+         * I'm not sure if this is the correct way to read if some
+         * block exists. Naively I would expect this to unnecessarily
+         * copy some data, but I could be wrong.
+         */
+        match self.db.cf_handle(col) {
+            None => Err(DBError {
+                message: "Unknown column".to_string(),
+            }),
+            Some(handle) => Ok(self.db.get_cf(handle, key)?.is_some()),
+        }
+    }
+
+    /// Delete the value for some key on some column.
+    ///
+    /// Corresponds to the `delete_cf()` method on the RocksDB API.
+    /// Will attempt to get the `ColumnFamily` and return an Err
+    /// if it fails.
+    fn delete(&self, col: &str, key: &[u8]) -> Result<(), DBError> {
+        match self.db.cf_handle(col) {
+            None => Err(DBError {
+                message: "Unknown column".to_string(),
+            }),
+            Some(handle) => {
+                self.db.delete_cf(handle, key)?;
+                Ok(())
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::super::ClientDB;
+    use super::*;
+    use std::sync::Arc;
+    use std::{env, fs, thread};
+
+    #[test]
+    #[ignore]
+    fn test_rocksdb_can_use_db() {
+        let pwd = env::current_dir().unwrap();
+        let path = pwd.join("testdb_please_remove");
+        let _ = fs::remove_dir_all(&path);
+        fs::create_dir_all(&path).unwrap();
+
+        let col_name: &str = "TestColumn";
+        let column_families = vec![col_name];
+
+        let mut db = DiskDB::open(&path, None);
+
+        for cf in column_families {
+            db.create_col(&cf).unwrap();
+        }
+
+        let db = Arc::new(db);
+
+        let thread_count = 10;
+        let write_count = 10;
+
+        // We're expecting the product of these numbers to fit in one byte.
+        assert!(thread_count * write_count <= 255);
+
+        let mut handles = vec![];
+        for t in 0..thread_count {
+            let wc = write_count;
+            let db = db.clone();
+            let col = col_name.clone();
+            let handle = thread::spawn(move || {
+                for w in 0..wc {
+                    let key = (t * w) as u8;
+                    let val = 42;
+                    db.put(&col, &vec![key], &vec![val]).unwrap();
+                }
+            });
+            handles.push(handle);
+        }
+
+        for handle in handles {
+            handle.join().unwrap();
+        }
+
+        for t in 0..thread_count {
+            for w in 0..write_count {
+                let key = (t * w) as u8;
+                let val = db.get(&col_name, &vec![key]).unwrap().unwrap();
+                assert_eq!(vec![42], val);
+            }
+        }
+        fs::remove_dir_all(&path).unwrap();
+    }
+}
diff --git a/beacon_node/db2/src/lib.rs b/beacon_node/db2/src/lib.rs
new file mode 100644
index 000000000..0704a84f5
--- /dev/null
+++ b/beacon_node/db2/src/lib.rs
@@ -0,0 +1,151 @@
+extern crate blake2_rfc as blake2;
+extern crate bls;
+extern crate rocksdb;
+
+mod disk_db;
+mod memory_db;
+pub mod stores;
+mod traits;
+
+use self::stores::COLUMNS;
+use db_encode::{db_encode, DBDecode, DBEncode};
+use ssz::DecodeError;
+use std::sync::Arc;
+
+pub use self::disk_db::DiskDB;
+pub use self::memory_db::MemoryDB;
+pub use self::traits::{ClientDB, DBError, DBValue};
+pub use types::*;
+
+#[derive(Debug, PartialEq)]
+pub enum Error {
+    SszDecodeError(DecodeError),
+    DBError { message: String },
+}
+
+impl From<DecodeError> for Error {
+    fn from(e: DecodeError) -> Error {
+        Error::SszDecodeError(e)
+    }
+}
+
+impl From<DBError> for Error {
+    fn from(e: DBError) -> Error {
+        Error::DBError { message: e.message }
+    }
+}
+
+/// Currently available database options.
+#[derive(Debug, Clone)]
+pub enum DBType {
+    Memory,
+    RocksDB,
+}
+
+pub enum DBColumn {
+    Block,
+    State,
+    BeaconChain,
+}
+
+impl<'a> Into<&'a str> for DBColumn {
+    /// Returns a `&str` that can be used for keying a key-value database.
+    fn into(self) -> &'a str {
+        match self {
+            DBColumn::Block => &"blk",
+            DBColumn::State => &"ste",
+            DBColumn::BeaconChain => &"bch",
+        }
+    }
+}
+
+pub trait DBRecord: DBEncode + DBDecode {
+    fn db_column() -> DBColumn;
+}
+
+pub struct Store<T>
+where
+    T: ClientDB,
+{
+    db: Arc<T>,
+}
+
+impl Store<MemoryDB> {
+    fn new_in_memory() -> Self {
+        Self {
+            db: Arc::new(MemoryDB::open()),
+        }
+    }
+}
+
+impl<T> Store<T>
+where
+    T: ClientDB,
+{
+    /// Put `item` in the store as `key`.
+    ///
+    /// The `item` must implement `DBRecord` which defines the db column used.
+    fn put<I>(&self, key: &Hash256, item: &I) -> Result<(), Error>
+    where
+        I: DBRecord,
+    {
+        let column = I::db_column().into();
+        let key = key.as_bytes();
+        let val = db_encode(item);
+
+        self.db.put(column, key, &val).map_err(|e| e.into())
+    }
+
+    /// Retrieves an `Ok(Some(item))` from the store if `key` exists, otherwise returns `Ok(None)`.
+    ///
+    /// The `item` must implement `DBRecord` which defines the db column used.
+    fn get<I>(&self, key: &Hash256) -> Result<Option<I>, Error>
+    where
+        I: DBRecord,
+    {
+        let column = I::db_column().into();
+        let key = key.as_bytes();
+
+        match self.db.get(column, key)? {
+            Some(bytes) => {
+                let (item, _index) = I::db_decode(&bytes, 0)?;
+                Ok(Some(item))
+            }
+            None => Ok(None),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use db_encode_derive::{DBDecode, DBEncode};
+    use ssz::Decodable;
+    use ssz_derive::{Decode, Encode};
+
+    #[derive(PartialEq, Debug, Encode, Decode, DBEncode, DBDecode)]
+    struct StorableThing {
+        a: u64,
+        b: u64,
+    }
+
+    impl DBRecord for StorableThing {
+        fn db_column() -> DBColumn {
+            DBColumn::Block
+        }
+    }
+
+    #[test]
+    fn memorydb_can_store() {
+        let store = Store::new_in_memory();
+
+        let key = Hash256::random();
+        let item = StorableThing { a: 1, b: 42 };
+
+        store.put(&key, &item).unwrap();
+
+        let retrieved = store.get(&key).unwrap().unwrap();
+
+        assert_eq!(item, retrieved);
+    }
+}
diff --git a/beacon_node/db2/src/memory_db.rs b/beacon_node/db2/src/memory_db.rs
new file mode 100644
index 000000000..008e5912f
--- /dev/null
+++ b/beacon_node/db2/src/memory_db.rs
@@ -0,0 +1,236 @@
+use super::blake2::blake2b::blake2b;
+use super::COLUMNS;
+use super::{ClientDB, DBError, DBValue};
+use std::collections::{HashMap, HashSet};
+use std::sync::RwLock;
+
+type DBHashMap = HashMap<Vec<u8>, Vec<u8>>;
+type ColumnHashSet = HashSet<String>;
+
+/// An in-memory database implementing the ClientDB trait.
+///
+/// It is not particularly optimized; it exists for ease and speed of testing. It's not expected
+/// this DB would be used outside of tests.
+pub struct MemoryDB {
+    db: RwLock<DBHashMap>,
+    known_columns: RwLock<ColumnHashSet>,
+}
+
+impl MemoryDB {
+    /// Open the in-memory database.
+    ///
+    /// All columns must be supplied initially; you will get an error if you try to access a column
+    /// that was not declared here. This condition is enforced artificially to simulate RocksDB.
+    pub fn open() -> Self {
+        let db: DBHashMap = HashMap::new();
+        let mut known_columns: ColumnHashSet = HashSet::new();
+        for col in &COLUMNS {
+            known_columns.insert(col.to_string());
+        }
+        Self {
+            db: RwLock::new(db),
+            known_columns: RwLock::new(known_columns),
+        }
+    }
+
+    /// Hashes a key and a column name in order to get a unique key for the supplied column.
+    fn get_key_for_col(col: &str, key: &[u8]) -> Vec<u8> {
+        blake2b(32, col.as_bytes(), key).as_bytes().to_vec()
+    }
+}
+
+impl ClientDB for MemoryDB {
+    /// Get the value of some key from the database. Returns `None` if the key does not exist.
+    fn get(&self, col: &str, key: &[u8]) -> Result<Option<DBValue>, DBError> {
+        // Panic if the DB locks are poisoned.
+        let db = self.db.read().unwrap();
+        let known_columns = self.known_columns.read().unwrap();
+
+        if known_columns.contains(&col.to_string()) {
+            let column_key = MemoryDB::get_key_for_col(col, key);
+            Ok(db.get(&column_key).and_then(|val| Some(val.clone())))
+        } else {
+            Err(DBError {
+                message: "Unknown column".to_string(),
+            })
+        }
+    }
+
+    /// Puts a key in the database.
+    fn put(&self, col: &str, key: &[u8], val: &[u8]) -> Result<(), DBError> {
+        // Panic if the DB locks are poisoned.
+        let mut db = self.db.write().unwrap();
+        let known_columns = self.known_columns.read().unwrap();
+
+        if known_columns.contains(&col.to_string()) {
+            let column_key = MemoryDB::get_key_for_col(col, key);
+            db.insert(column_key, val.to_vec());
+            Ok(())
+        } else {
+            Err(DBError {
+                message: "Unknown column".to_string(),
+            })
+        }
+    }
+
+    /// Return true if some key exists in some column.
+    fn exists(&self, col: &str, key: &[u8]) -> Result<bool, DBError> {
+        // Panic if the DB locks are poisoned.
+        let db = self.db.read().unwrap();
+        let known_columns = self.known_columns.read().unwrap();
+
+        if known_columns.contains(&col.to_string()) {
+            let column_key = MemoryDB::get_key_for_col(col, key);
+            Ok(db.contains_key(&column_key))
+        } else {
+            Err(DBError {
+                message: "Unknown column".to_string(),
+            })
+        }
+    }
+
+    /// Delete some key from the database.
+    fn delete(&self, col: &str, key: &[u8]) -> Result<(), DBError> {
+        // Panic if the DB locks are poisoned.
+        let mut db = self.db.write().unwrap();
+        let known_columns = self.known_columns.read().unwrap();
+
+        if known_columns.contains(&col.to_string()) {
+            let column_key = MemoryDB::get_key_for_col(col, key);
+            db.remove(&column_key);
+            Ok(())
+        } else {
+            Err(DBError {
+                message: "Unknown column".to_string(),
+            })
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::super::stores::{BLOCKS_DB_COLUMN, VALIDATOR_DB_COLUMN};
+    use super::super::ClientDB;
+    use super::*;
+    use std::sync::Arc;
+    use std::thread;
+
+    #[test]
+    fn test_memorydb_can_delete() {
+        let col_a: &str = BLOCKS_DB_COLUMN;
+
+        let db = MemoryDB::open();
+
+        db.put(col_a, "dogs".as_bytes(), "lol".as_bytes()).unwrap();
+
+        assert_eq!(
+            db.get(col_a, "dogs".as_bytes()).unwrap().unwrap(),
+            "lol".as_bytes()
+        );
+
+        db.delete(col_a, "dogs".as_bytes()).unwrap();
+
+        assert_eq!(db.get(col_a, "dogs".as_bytes()).unwrap(), None);
+    }
+
+    #[test]
+    fn test_memorydb_column_access() {
+        let col_a: &str = BLOCKS_DB_COLUMN;
+        let col_b: &str = VALIDATOR_DB_COLUMN;
+
+        let db = MemoryDB::open();
+
+        /*
+         * Testing that if we write to the same key in different columns that
+         * there is not an overlap.
+         */
+        db.put(col_a, "same".as_bytes(), "cat".as_bytes()).unwrap();
+        db.put(col_b, "same".as_bytes(), "dog".as_bytes()).unwrap();
+
+        assert_eq!(
+            db.get(col_a, "same".as_bytes()).unwrap().unwrap(),
+            "cat".as_bytes()
+        );
+        assert_eq!(
+            db.get(col_b, "same".as_bytes()).unwrap().unwrap(),
+            "dog".as_bytes()
+        );
+    }
+
+    #[test]
+    fn test_memorydb_unknown_column_access() {
+        let col_a: &str = BLOCKS_DB_COLUMN;
+        let col_x: &str = "ColumnX";
+
+        let db = MemoryDB::open();
+
+        /*
+         * Test that we get errors when using undeclared columns.
+         */
+        assert!(db.put(col_a, "cats".as_bytes(), "lol".as_bytes()).is_ok());
+        assert!(db.put(col_x, "cats".as_bytes(), "lol".as_bytes()).is_err());
+
+        assert!(db.get(col_a, "cats".as_bytes()).is_ok());
+        assert!(db.get(col_x, "cats".as_bytes()).is_err());
+    }
+
+    #[test]
+    fn test_memorydb_exists() {
+        let col_a: &str = BLOCKS_DB_COLUMN;
+        let col_b: &str = VALIDATOR_DB_COLUMN;
+
+        let db = MemoryDB::open();
+
+        /*
+         * Testing that a key only `exists` in the column it was written to.
+         */
+        db.put(col_a, "cats".as_bytes(), "lol".as_bytes()).unwrap();
+
+        assert_eq!(true, db.exists(col_a, "cats".as_bytes()).unwrap());
+        assert_eq!(false, db.exists(col_b, "cats".as_bytes()).unwrap());
+
+        assert_eq!(false, db.exists(col_a, "dogs".as_bytes()).unwrap());
+        assert_eq!(false, db.exists(col_b, "dogs".as_bytes()).unwrap());
+    }
+
+    #[test]
+    fn test_memorydb_threading() {
+        let col_name: &str = BLOCKS_DB_COLUMN;
+
+        let db = Arc::new(MemoryDB::open());
+
+        let thread_count = 10;
+        let write_count = 10;
+
+        // We're expecting the product of these numbers to fit in one byte.
+        assert!(thread_count * write_count <= 255);
+
+        let mut handles = vec![];
+        for t in 0..thread_count {
+            let wc = write_count;
+            let db = db.clone();
+            let col = col_name.clone();
+            let handle = thread::spawn(move || {
+                for w in 0..wc {
+                    let key = (t * w) as u8;
+                    let val = 42;
+                    db.put(&col, &vec![key], &vec![val]).unwrap();
+                }
+            });
+            handles.push(handle);
+        }
+
+        for handle in handles {
+            handle.join().unwrap();
+        }
+
+        for t in 0..thread_count {
+            for w in 0..write_count {
+                let key = (t * w) as u8;
+                let val = db.get(&col_name, &vec![key]).unwrap().unwrap();
+                assert_eq!(vec![42], val);
+            }
+        }
+    }
+}
diff --git a/beacon_node/db2/src/stores/beacon_block_store.rs b/beacon_node/db2/src/stores/beacon_block_store.rs
new file mode 100644
index 000000000..e2e16e60b
--- /dev/null
+++ b/beacon_node/db2/src/stores/beacon_block_store.rs
@@ -0,0 +1,246 @@
+use super::BLOCKS_DB_COLUMN as DB_COLUMN;
+use super::{ClientDB, DBError};
+use ssz::Decodable;
+use std::sync::Arc;
+use types::{BeaconBlock, Hash256, Slot};
+
+#[derive(Clone, Debug, PartialEq)]
+pub enum BeaconBlockAtSlotError {
+    UnknownBeaconBlock(Hash256),
+    InvalidBeaconBlock(Hash256),
+    DBError(String),
+}
+
+pub struct BeaconBlockStore<T>
+where
+    T: ClientDB,
+{
+    db: Arc<T>,
+}
+
+// Implements `put`, `get`, `exists` and `delete` for the store.
+impl_crud_for_store!(BeaconBlockStore, DB_COLUMN);
+
+impl<T: ClientDB> BeaconBlockStore<T> {
+    pub fn new(db: Arc<T>) -> Self {
+        Self { db }
+    }
+
+    pub fn get_deserialized(&self, hash: &Hash256) -> Result<Option<BeaconBlock>, DBError> {
+        match self.get(&hash)? {
+            None => Ok(None),
+            Some(ssz) => {
+                let (block, _) = BeaconBlock::ssz_decode(&ssz, 0).map_err(|_| DBError {
+                    message: "Bad BeaconBlock SSZ.".to_string(),
+                })?;
+                Ok(Some(block))
+            }
+        }
+    }
+
+    /// Retrieve the block at a slot given a "head_hash" and a slot.
+    ///
+    /// A "head_hash" must be a block hash with a slot number greater than or equal to the desired
+    /// slot.
+    ///
+    /// This function will read each block down the chain until it finds a block with the given
+    /// slot number. If the slot is skipped, the function will return None.
+    ///
+    /// If a block is found, a tuple of (block_hash, deserialized_block) is returned.
+    ///
+    /// Note: this function uses a loop instead of recursion as the compiler is over-strict when it
+    /// comes to recursion and the `impl Trait` pattern. See:
+    /// https://stackoverflow.com/questions/54032940/using-impl-trait-in-a-recursive-function
+    pub fn block_at_slot(
+        &self,
+        head_hash: &Hash256,
+        slot: Slot,
+    ) -> Result<Option<(Hash256, BeaconBlock)>, BeaconBlockAtSlotError> {
+        let mut current_hash = *head_hash;
+
+        loop {
+            if let Some(block) = self.get_deserialized(&current_hash)?
{ + if block.slot == slot { + break Ok(Some((current_hash, block))); + } else if block.slot < slot { + break Ok(None); + } else { + current_hash = block.previous_block_root; + } + } else { + break Err(BeaconBlockAtSlotError::UnknownBeaconBlock(current_hash)); + } + } + } +} + +impl From for BeaconBlockAtSlotError { + fn from(e: DBError) -> Self { + BeaconBlockAtSlotError::DBError(e.message) + } +} + +#[cfg(test)] +mod tests { + use super::super::super::MemoryDB; + use super::*; + + use std::sync::Arc; + use std::thread; + + use ssz::ssz_encode; + use types::test_utils::{SeedableRng, TestRandom, XorShiftRng}; + use types::BeaconBlock; + use types::Hash256; + + test_crud_for_store!(BeaconBlockStore, DB_COLUMN); + + #[test] + fn head_hash_slot_too_low() { + let db = Arc::new(MemoryDB::open()); + let bs = Arc::new(BeaconBlockStore::new(db.clone())); + let mut rng = XorShiftRng::from_seed([42; 16]); + + let mut block = BeaconBlock::random_for_test(&mut rng); + block.slot = Slot::from(10_u64); + + let block_root = block.canonical_root(); + bs.put(&block_root, &ssz_encode(&block)).unwrap(); + + let result = bs.block_at_slot(&block_root, Slot::from(11_u64)).unwrap(); + assert_eq!(result, None); + } + + #[test] + fn test_invalid_block_at_slot() { + let db = Arc::new(MemoryDB::open()); + let store = BeaconBlockStore::new(db.clone()); + + let ssz = "definitly not a valid block".as_bytes(); + let hash = &Hash256::from([0xAA; 32]); + + db.put(DB_COLUMN, hash.as_bytes(), ssz).unwrap(); + assert_eq!( + store.block_at_slot(hash, Slot::from(42_u64)), + Err(BeaconBlockAtSlotError::DBError( + "Bad BeaconBlock SSZ.".into() + )) + ); + } + + #[test] + fn test_unknown_block_at_slot() { + let db = Arc::new(MemoryDB::open()); + let store = BeaconBlockStore::new(db.clone()); + + let ssz = "some bytes".as_bytes(); + let hash = &Hash256::from([0xAA; 32]); + let other_hash = &Hash256::from([0xBB; 32]); + + db.put(DB_COLUMN, hash.as_bytes(), ssz).unwrap(); + assert_eq!( + store.block_at_slot(other_hash, Slot::from(42_u64)), + Err(BeaconBlockAtSlotError::UnknownBeaconBlock(*other_hash)) + ); + } + + #[test] + fn test_block_store_on_memory_db() { + let db = Arc::new(MemoryDB::open()); + let bs = Arc::new(BeaconBlockStore::new(db.clone())); + + let thread_count = 10; + let write_count = 10; + + let mut handles = vec![]; + for t in 0..thread_count { + let wc = write_count; + let bs = bs.clone(); + let handle = thread::spawn(move || { + for w in 0..wc { + let key = t * w; + let val = 42; + bs.put(&Hash256::from_low_u64_le(key), &vec![val]).unwrap(); + } + }); + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } + + for t in 0..thread_count { + for w in 0..write_count { + let key = t * w; + assert!(bs.exists(&Hash256::from_low_u64_le(key)).unwrap()); + let val = bs.get(&Hash256::from_low_u64_le(key)).unwrap().unwrap(); + assert_eq!(vec![42], val); + } + } + } + + #[test] + #[ignore] + fn test_block_at_slot() { + let db = Arc::new(MemoryDB::open()); + let bs = Arc::new(BeaconBlockStore::new(db.clone())); + let mut rng = XorShiftRng::from_seed([42; 16]); + + // Specify test block parameters. + let hashes = [ + Hash256::from([0; 32]), + Hash256::from([1; 32]), + Hash256::from([2; 32]), + Hash256::from([3; 32]), + Hash256::from([4; 32]), + ]; + let parent_hashes = [ + Hash256::from([255; 32]), // Genesis block. 
+            Hash256::from([0; 32]),
+            Hash256::from([1; 32]),
+            Hash256::from([2; 32]),
+            Hash256::from([3; 32]),
+        ];
+        let unknown_hash = Hash256::from([101; 32]); // different from all above
+        let slots: Vec<Slot> = vec![0, 1, 3, 4, 5].iter().map(|x| Slot::new(*x)).collect();
+
+        // Generate a vec of random blocks and store them in the DB.
+        let block_count = 5;
+        let mut blocks: Vec<BeaconBlock> = Vec::with_capacity(5);
+        for i in 0..block_count {
+            let mut block = BeaconBlock::random_for_test(&mut rng);
+
+            block.previous_block_root = parent_hashes[i];
+            block.slot = slots[i];
+
+            let ssz = ssz_encode(&block);
+            db.put(DB_COLUMN, hashes[i].as_bytes(), &ssz).unwrap();
+
+            blocks.push(block);
+        }
+
+        // Test that certain slots can be reached from certain hashes.
+        let test_cases = vec![(4, 4), (4, 3), (4, 2), (4, 1), (4, 0)];
+        for (hashes_index, slot_index) in test_cases {
+            let (matched_block_hash, block) = bs
+                .block_at_slot(&hashes[hashes_index], slots[slot_index])
+                .unwrap()
+                .unwrap();
+            assert_eq!(matched_block_hash, hashes[slot_index]);
+            assert_eq!(block.slot, slots[slot_index]);
+        }
+
+        let ssz = bs.block_at_slot(&hashes[4], Slot::new(2)).unwrap();
+        assert_eq!(ssz, None);
+
+        let ssz = bs.block_at_slot(&hashes[4], Slot::new(6)).unwrap();
+        assert_eq!(ssz, None);
+
+        let ssz = bs.block_at_slot(&unknown_hash, Slot::new(2));
+        assert_eq!(
+            ssz,
+            Err(BeaconBlockAtSlotError::UnknownBeaconBlock(unknown_hash))
+        );
+    }
+}
diff --git a/beacon_node/db2/src/stores/beacon_state_store.rs b/beacon_node/db2/src/stores/beacon_state_store.rs
new file mode 100644
index 000000000..fd6ff569a
--- /dev/null
+++ b/beacon_node/db2/src/stores/beacon_state_store.rs
@@ -0,0 +1,62 @@
+use super::STATES_DB_COLUMN as DB_COLUMN;
+use super::{ClientDB, DBError};
+use ssz::Decodable;
+use std::sync::Arc;
+use types::{BeaconState, Hash256};
+
+pub struct BeaconStateStore<T>
+where
+    T: ClientDB,
+{
+    db: Arc<T>,
+}
+
+// Implements `put`, `get`, `exists` and `delete` for the store.
+impl_crud_for_store!(BeaconStateStore, DB_COLUMN);
+
+impl<T: ClientDB> BeaconStateStore<T> {
+    pub fn new(db: Arc<T>) -> Self {
+        Self { db }
+    }
+
+    pub fn get_deserialized(&self, hash: &Hash256) -> Result<Option<BeaconState>, DBError> {
+        match self.get(&hash)? {
+            None => Ok(None),
+            Some(ssz) => {
+                let (state, _) = BeaconState::ssz_decode(&ssz, 0).map_err(|_| DBError {
+                    message: "Bad State SSZ.".to_string(),
+                })?;
+                Ok(Some(state))
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::super::super::MemoryDB;
+    use super::*;
+
+    use ssz::ssz_encode;
+    use std::sync::Arc;
+    use types::test_utils::{SeedableRng, TestRandom, XorShiftRng};
+    use types::Hash256;
+
+    test_crud_for_store!(BeaconStateStore, DB_COLUMN);
+
+    #[test]
+    fn test_reader() {
+        let db = Arc::new(MemoryDB::open());
+        let store = BeaconStateStore::new(db.clone());
+
+        let mut rng = XorShiftRng::from_seed([42; 16]);
+        let state = BeaconState::random_for_test(&mut rng);
+        let state_root = state.canonical_root();
+
+        store.put(&state_root, &ssz_encode(&state)).unwrap();
+
+        let decoded = store.get_deserialized(&state_root).unwrap().unwrap();
+
+        assert_eq!(state, decoded);
+    }
+}
diff --git a/beacon_node/db2/src/stores/macros.rs b/beacon_node/db2/src/stores/macros.rs
new file mode 100644
index 000000000..6c53e40ee
--- /dev/null
+++ b/beacon_node/db2/src/stores/macros.rs
@@ -0,0 +1,103 @@
+macro_rules! impl_crud_for_store {
+    ($store: ident, $db_column: expr) => {
+        impl<T: ClientDB> $store<T> {
+            pub fn put(&self, hash: &Hash256, ssz: &[u8]) -> Result<(), DBError> {
+                self.db.put($db_column, hash.as_bytes(), ssz)
+            }
+
+            pub fn get(&self, hash: &Hash256) -> Result<Option<Vec<u8>>, DBError> {
+                self.db.get($db_column, hash.as_bytes())
+            }
+
+            pub fn exists(&self, hash: &Hash256) -> Result<bool, DBError> {
+                self.db.exists($db_column, hash.as_bytes())
+            }
+
+            pub fn delete(&self, hash: &Hash256) -> Result<(), DBError> {
+                self.db.delete($db_column, hash.as_bytes())
+            }
+        }
+    };
+}
+
+#[cfg(test)]
+macro_rules! test_crud_for_store {
+    ($store: ident, $db_column: expr) => {
+        #[test]
+        fn test_put() {
+            let db = Arc::new(MemoryDB::open());
+            let store = $store::new(db.clone());
+
+            let ssz = "some bytes".as_bytes();
+            let hash = &Hash256::from([0xAA; 32]);
+
+            store.put(hash, ssz).unwrap();
+            assert_eq!(db.get(DB_COLUMN, hash.as_bytes()).unwrap().unwrap(), ssz);
+        }
+
+        #[test]
+        fn test_get() {
+            let db = Arc::new(MemoryDB::open());
+            let store = $store::new(db.clone());
+
+            let ssz = "some bytes".as_bytes();
+            let hash = &Hash256::from([0xAA; 32]);
+
+            db.put(DB_COLUMN, hash.as_bytes(), ssz).unwrap();
+            assert_eq!(store.get(hash).unwrap().unwrap(), ssz);
+        }
+
+        #[test]
+        fn test_get_unknown() {
+            let db = Arc::new(MemoryDB::open());
+            let store = $store::new(db.clone());
+
+            let ssz = "some bytes".as_bytes();
+            let hash = &Hash256::from([0xAA; 32]);
+            let other_hash = &Hash256::from([0xBB; 32]);
+
+            db.put(DB_COLUMN, other_hash.as_bytes(), ssz).unwrap();
+            assert_eq!(store.get(hash).unwrap(), None);
+        }
+
+        #[test]
+        fn test_exists() {
+            let db = Arc::new(MemoryDB::open());
+            let store = $store::new(db.clone());
+
+            let ssz = "some bytes".as_bytes();
+            let hash = &Hash256::from([0xAA; 32]);
+
+            db.put(DB_COLUMN, hash.as_bytes(), ssz).unwrap();
+            assert!(store.exists(hash).unwrap());
+        }
+
+        #[test]
+        fn test_block_does_not_exist() {
+            let db = Arc::new(MemoryDB::open());
+            let store = $store::new(db.clone());
+
+            let ssz = "some bytes".as_bytes();
+            let hash = &Hash256::from([0xAA; 32]);
+            let other_hash = &Hash256::from([0xBB; 32]);
+
+            db.put(DB_COLUMN, hash.as_bytes(), ssz).unwrap();
+            assert!(!store.exists(other_hash).unwrap());
+        }
+
+        #[test]
+        fn test_delete() {
+            let db = Arc::new(MemoryDB::open());
+            let store = $store::new(db.clone());
+
+            let ssz = "some bytes".as_bytes();
+            let hash = &Hash256::from([0xAA; 32]);
+
+            db.put(DB_COLUMN, hash.as_bytes(), ssz).unwrap();
+            assert!(db.exists(DB_COLUMN, hash.as_bytes()).unwrap());
+
+            store.delete(hash).unwrap();
+            assert!(!db.exists(DB_COLUMN, hash.as_bytes()).unwrap());
+        }
+    };
+}
diff --git a/beacon_node/db2/src/stores/mod.rs b/beacon_node/db2/src/stores/mod.rs
new file mode 100644
index 000000000..44de7eed1
--- /dev/null
+++ b/beacon_node/db2/src/stores/mod.rs
@@ -0,0 +1,25 @@
+use super::{ClientDB, DBError};
+
+#[macro_use]
+mod macros;
+mod beacon_block_store;
+mod beacon_state_store;
+mod pow_chain_store;
+mod validator_store;
+
+pub use self::beacon_block_store::{BeaconBlockAtSlotError, BeaconBlockStore};
+pub use self::beacon_state_store::BeaconStateStore;
+pub use self::pow_chain_store::PoWChainStore;
+pub use self::validator_store::{ValidatorStore, ValidatorStoreError};
+
+pub const BLOCKS_DB_COLUMN: &str = "blocks";
+pub const STATES_DB_COLUMN: &str = "states";
+pub const POW_CHAIN_DB_COLUMN: &str = "powchain";
+pub const VALIDATOR_DB_COLUMN: &str = "validator";
+
+pub const COLUMNS: [&str; 4] = [
+    BLOCKS_DB_COLUMN,
+    STATES_DB_COLUMN,
+    POW_CHAIN_DB_COLUMN,
+    VALIDATOR_DB_COLUMN,
+];
diff --git a/beacon_node/db2/src/stores/pow_chain_store.rs b/beacon_node/db2/src/stores/pow_chain_store.rs
new file mode 100644
index 000000000..5c8b97907
--- /dev/null
+++ b/beacon_node/db2/src/stores/pow_chain_store.rs
@@ -0,0 +1,68 @@
+use super::POW_CHAIN_DB_COLUMN as DB_COLUMN;
+use super::{ClientDB, DBError};
+use std::sync::Arc;
+
+pub struct PoWChainStore<T>
+where
+    T: ClientDB,
+{
+    db: Arc<T>,
+}
+
+impl<T: ClientDB> PoWChainStore<T> {
+    pub fn new(db: Arc<T>) -> Self {
+        Self { db }
+    }
+
+    pub fn put_block_hash(&self, hash: &[u8]) -> Result<(), DBError> {
+        self.db.put(DB_COLUMN, hash, &[0])
+    }
+
+    pub fn block_hash_exists(&self, hash: &[u8]) -> Result<bool, DBError> {
+        self.db.exists(DB_COLUMN, hash)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    extern crate types;
+
+    use super::super::super::MemoryDB;
+    use super::*;
+
+    use self::types::Hash256;
+
+    #[test]
+    fn test_put_block_hash() {
+        let db = Arc::new(MemoryDB::open());
+        let store = PoWChainStore::new(db.clone());
+
+        let hash = &Hash256::from([0xAA; 32]).as_bytes().to_vec();
+        store.put_block_hash(hash).unwrap();
+
+        assert!(db.exists(DB_COLUMN, hash).unwrap());
+    }
+
+    #[test]
+    fn test_block_hash_exists() {
+        let db = Arc::new(MemoryDB::open());
+        let store = PoWChainStore::new(db.clone());
+
+        let hash = &Hash256::from([0xAA; 32]).as_bytes().to_vec();
+        db.put(DB_COLUMN, hash, &[0]).unwrap();
+
+        assert!(store.block_hash_exists(hash).unwrap());
+    }
+
+    #[test]
+    fn test_block_hash_does_not_exist() {
+        let db = Arc::new(MemoryDB::open());
+        let store = PoWChainStore::new(db.clone());
+
+        let hash = &Hash256::from([0xAA; 32]).as_bytes().to_vec();
+        let other_hash = &Hash256::from([0xBB; 32]).as_bytes().to_vec();
+        db.put(DB_COLUMN, hash, &[0]).unwrap();
+
+        assert!(!store.block_hash_exists(other_hash).unwrap());
+    }
}
diff --git a/beacon_node/db2/src/stores/validator_store.rs b/beacon_node/db2/src/stores/validator_store.rs
new file mode 100644
index 000000000..02e90dc5c
--- /dev/null
+++ b/beacon_node/db2/src/stores/validator_store.rs
@@ -0,0 +1,215 @@
+extern crate bytes;
+
+use self::bytes::{BufMut, BytesMut};
+use super::VALIDATOR_DB_COLUMN as DB_COLUMN;
+use super::{ClientDB, DBError};
+use bls::PublicKey;
+use ssz::{ssz_encode, Decodable};
+use std::sync::Arc;
+
+#[derive(Debug, PartialEq)]
+pub enum ValidatorStoreError {
+    DBError(String),
+    DecodeError,
+}
+
+impl From<DBError> for ValidatorStoreError {
+    fn from(error: DBError) -> Self {
+        ValidatorStoreError::DBError(error.message)
+    }
+}
+
+#[derive(Debug, PartialEq)]
+enum KeyPrefixes {
+    PublicKey,
+}
+
+pub struct ValidatorStore<T>
+where
+    T: ClientDB,
+{
+    db: Arc<T>,
+}
+
+impl<T: ClientDB> ValidatorStore<T> {
+    pub fn new(db: Arc<T>) -> Self {
+        Self { db }
+    }
+
+    fn prefix_bytes(&self, key_prefix: &KeyPrefixes) -> Vec<u8> {
+        match key_prefix {
+            KeyPrefixes::PublicKey => b"pubkey".to_vec(),
+        }
+    }
+
+    fn get_db_key_for_index(&self, key_prefix: &KeyPrefixes, index: usize) -> Vec<u8> {
+        let mut buf = BytesMut::with_capacity(6 + 8);
+        buf.put(self.prefix_bytes(key_prefix));
+        buf.put_u64_be(index as u64);
+        buf.take().to_vec()
+    }
+
+    pub fn put_public_key_by_index(
+        &self,
+        index: usize,
+        public_key: &PublicKey,
+    ) -> Result<(), ValidatorStoreError> {
+        let key = self.get_db_key_for_index(&KeyPrefixes::PublicKey, index);
+        let val = ssz_encode(public_key);
+        self.db
+            .put(DB_COLUMN, &key[..], &val[..])
+            .map_err(ValidatorStoreError::from)
+    }
+
+    pub fn get_public_key_by_index(
+        &self,
+        index: usize,
+    ) -> Result<Option<PublicKey>, ValidatorStoreError> {
+        let key = self.get_db_key_for_index(&KeyPrefixes::PublicKey,
index); + let val = self.db.get(DB_COLUMN, &key[..])?; + match val { + None => Ok(None), + Some(val) => match PublicKey::ssz_decode(&val, 0) { + Ok((key, _)) => Ok(Some(key)), + Err(_) => Err(ValidatorStoreError::DecodeError), + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::super::super::MemoryDB; + use super::*; + use bls::Keypair; + + #[test] + fn test_prefix_bytes() { + let db = Arc::new(MemoryDB::open()); + let store = ValidatorStore::new(db.clone()); + + assert_eq!( + store.prefix_bytes(&KeyPrefixes::PublicKey), + b"pubkey".to_vec() + ); + } + + #[test] + fn test_get_db_key_for_index() { + let db = Arc::new(MemoryDB::open()); + let store = ValidatorStore::new(db.clone()); + + let mut buf = BytesMut::with_capacity(6 + 8); + buf.put(b"pubkey".to_vec()); + buf.put_u64_be(42); + assert_eq!( + store.get_db_key_for_index(&KeyPrefixes::PublicKey, 42), + buf.take().to_vec() + ) + } + + #[test] + fn test_put_public_key_by_index() { + let db = Arc::new(MemoryDB::open()); + let store = ValidatorStore::new(db.clone()); + + let index = 3; + let public_key = Keypair::random().pk; + + store.put_public_key_by_index(index, &public_key).unwrap(); + let public_key_at_index = db + .get( + DB_COLUMN, + &store.get_db_key_for_index(&KeyPrefixes::PublicKey, index)[..], + ) + .unwrap() + .unwrap(); + + assert_eq!(public_key_at_index, ssz_encode(&public_key)); + } + + #[test] + fn test_get_public_key_by_index() { + let db = Arc::new(MemoryDB::open()); + let store = ValidatorStore::new(db.clone()); + + let index = 4; + let public_key = Keypair::random().pk; + + db.put( + DB_COLUMN, + &store.get_db_key_for_index(&KeyPrefixes::PublicKey, index)[..], + &ssz_encode(&public_key)[..], + ) + .unwrap(); + + let public_key_at_index = store.get_public_key_by_index(index).unwrap().unwrap(); + assert_eq!(public_key_at_index, public_key); + } + + #[test] + fn test_get_public_key_by_unknown_index() { + let db = Arc::new(MemoryDB::open()); + let store = ValidatorStore::new(db.clone()); + + let public_key = Keypair::random().pk; + + db.put( + DB_COLUMN, + &store.get_db_key_for_index(&KeyPrefixes::PublicKey, 3)[..], + &ssz_encode(&public_key)[..], + ) + .unwrap(); + + let public_key_at_index = store.get_public_key_by_index(4).unwrap(); + assert_eq!(public_key_at_index, None); + } + + #[test] + fn test_get_invalid_public_key() { + let db = Arc::new(MemoryDB::open()); + let store = ValidatorStore::new(db.clone()); + + let key = store.get_db_key_for_index(&KeyPrefixes::PublicKey, 42); + db.put(DB_COLUMN, &key[..], "cats".as_bytes()).unwrap(); + + assert_eq!( + store.get_public_key_by_index(42), + Err(ValidatorStoreError::DecodeError) + ); + } + + #[test] + fn test_validator_store_put_get() { + let db = Arc::new(MemoryDB::open()); + let store = ValidatorStore::new(db); + + let keys = vec![ + Keypair::random(), + Keypair::random(), + Keypair::random(), + Keypair::random(), + Keypair::random(), + ]; + + for i in 0..keys.len() { + store.put_public_key_by_index(i, &keys[i].pk).unwrap(); + } + + /* + * Check all keys are retrieved correctly. + */ + for i in 0..keys.len() { + let retrieved = store.get_public_key_by_index(i).unwrap().unwrap(); + assert_eq!(retrieved, keys[i].pk); + } + + /* + * Check that an index that wasn't stored returns None. 
+         */
+        assert!(store
+            .get_public_key_by_index(keys.len() + 1)
+            .unwrap()
+            .is_none());
+    }
+}
diff --git a/beacon_node/db2/src/traits.rs b/beacon_node/db2/src/traits.rs
new file mode 100644
index 000000000..57ebf9353
--- /dev/null
+++ b/beacon_node/db2/src/traits.rs
@@ -0,0 +1,38 @@
+pub type DBValue = Vec<u8>;
+
+#[derive(Debug)]
+pub struct DBError {
+    pub message: String,
+}
+
+impl DBError {
+    pub fn new(message: String) -> Self {
+        Self { message }
+    }
+}
+
+/// A generic database to be used by the "client" (i.e.,
+/// the lighthouse blockchain client).
+///
+/// The purpose of having this generic trait is to allow the
+/// program to use a persistent on-disk database during production,
+/// but use a transient database during tests.
+pub trait ClientDB: Sync + Send {
+    fn get(&self, col: &str, key: &[u8]) -> Result<Option<DBValue>, DBError>;
+
+    fn put(&self, col: &str, key: &[u8], val: &[u8]) -> Result<(), DBError>;
+
+    fn exists(&self, col: &str, key: &[u8]) -> Result<bool, DBError>;
+
+    fn delete(&self, col: &str, key: &[u8]) -> Result<(), DBError>;
+}
+
+pub enum DBColumn {
+    Block,
+    State,
+    BeaconChain,
+}
+
+pub trait DBStore {
+    fn db_column(&self) -> DBColumn;
+}
diff --git a/beacon_node/db_encode/Cargo.toml b/beacon_node/db_encode/Cargo.toml
new file mode 100644
index 000000000..b4e919585
--- /dev/null
+++ b/beacon_node/db_encode/Cargo.toml
@@ -0,0 +1,9 @@
+[package]
+name = "db_encode"
+version = "0.1.0"
+authors = ["Paul Hauner "]
+edition = "2018"
+
+[dependencies]
+ethereum-types = "0.5"
+ssz = { path = "../../eth2/utils/ssz" }
diff --git a/beacon_node/db_encode/src/lib.rs b/beacon_node/db_encode/src/lib.rs
new file mode 100644
index 000000000..993ba0e79
--- /dev/null
+++ b/beacon_node/db_encode/src/lib.rs
@@ -0,0 +1,59 @@
+use ethereum_types::{Address, H256};
+use ssz::{ssz_encode, Decodable, DecodeError, Encodable, SszStream};
+
+/// Convenience function to encode an object.
+pub fn db_encode<T>(val: &T) -> Vec<u8>
+where
+    T: DBEncode,
+{
+    let mut ssz_stream = SszStream::new();
+    ssz_stream.append(val);
+    ssz_stream.drain()
+}
+
+/// An encoding scheme based solely upon SSZ.
+///
+/// The reason we have a separate encoding scheme is to allow us to store fields in the DB that we
+/// don't want to transmit across the wire or hash.
+///
+/// For example, the cache fields on `BeaconState` should be stored in the DB, but they should not
+/// be hashed or transmitted across the wire. `DBEncode` allows us to define two serialization
+/// methods, one that encodes the caches and one that does not.
+pub trait DBEncode: Encodable + Sized {
+    fn db_encode(&self, s: &mut SszStream) {
+        s.append(&ssz_encode(self));
+    }
+}
+
+/// A decoding scheme based solely upon SSZ.
+///
+/// See `DBEncode` for reasoning on why this trait exists.
+pub trait DBDecode: Decodable {
+    fn db_decode(bytes: &[u8], index: usize) -> Result<(Self, usize), DecodeError> {
+        Self::ssz_decode(bytes, index)
+    }
+}
+
+// Implement encoding.
+impl DBEncode for bool {}
+impl DBEncode for u8 {}
+impl DBEncode for u16 {}
+impl DBEncode for u32 {}
+impl DBEncode for u64 {}
+impl DBEncode for usize {}
+impl<T> DBEncode for Vec<T> where T: Encodable + Sized {}
+
+impl DBEncode for H256 {}
+impl DBEncode for Address {}
+
+// Implement decoding.
+impl DBDecode for bool {}
+impl DBDecode for u8 {}
+impl DBDecode for u16 {}
+impl DBDecode for u32 {}
+impl DBDecode for u64 {}
+impl DBDecode for usize {}
+impl<T> DBDecode for Vec<T> where T: Decodable {}
+
+impl DBDecode for H256 {}
+impl DBDecode for Address {}
diff --git a/beacon_node/db_encode_derive/Cargo.toml b/beacon_node/db_encode_derive/Cargo.toml
new file mode 100644
index 000000000..b2fba85e3
--- /dev/null
+++ b/beacon_node/db_encode_derive/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "db_encode_derive"
+version = "0.1.0"
+authors = ["Paul Hauner "]
+edition = "2018"
+description = "Procedural derive macros for `db_encode` encoding and decoding."
+
+[lib]
+proc-macro = true
+
+[dependencies]
+syn = "0.15"
+quote = "0.6"
diff --git a/beacon_node/db_encode_derive/src/lib.rs b/beacon_node/db_encode_derive/src/lib.rs
new file mode 100644
index 000000000..1de081419
--- /dev/null
+++ b/beacon_node/db_encode_derive/src/lib.rs
@@ -0,0 +1,305 @@
+extern crate proc_macro;
+
+use proc_macro::TokenStream;
+use quote::quote;
+use syn::{parse_macro_input, DeriveInput};
+
+/// Returns a Vec of `syn::Ident` for each named field in the struct.
+///
+/// # Panics
+/// Any unnamed struct field (like in a tuple struct) will raise a panic at compile time.
+fn get_named_field_idents<'a>(struct_data: &'a syn::DataStruct) -> Vec<&'a syn::Ident> {
+    struct_data
+        .fields
+        .iter()
+        .map(|f| match &f.ident {
+            Some(ref ident) => ident,
+            _ => panic!("db_derive only supports named struct fields."),
+        })
+        .collect()
+}
+
+/// Implements `db_encode::DBEncode` for some `struct`.
+///
+/// Fields are encoded in the order they are defined.
+#[proc_macro_derive(DBEncode)]
+pub fn db_encode_derive(input: TokenStream) -> TokenStream {
+    let item = parse_macro_input!(input as DeriveInput);
+
+    let name = &item.ident;
+
+    let struct_data = match &item.data {
+        syn::Data::Struct(s) => s,
+        _ => panic!("db_derive only supports structs."),
+    };
+
+    let field_idents = get_named_field_idents(&struct_data);
+
+    let output = quote! {
+        impl db_encode::DBEncode for #name {
+            fn db_encode(&self, s: &mut ssz::SszStream) {
+                #(
+                    s.append(&self.#field_idents);
+                )*
+            }
+        }
+    };
+    output.into()
+}
+
+/// Implements `db_encode::DBDecode` for some `struct`.
+///
+/// Fields are decoded in the order they are defined.
+#[proc_macro_derive(DBDecode)]
+pub fn db_decode_derive(input: TokenStream) -> TokenStream {
+    let item = parse_macro_input!(input as DeriveInput);
+
+    let name = &item.ident;
+
+    let struct_data = match &item.data {
+        syn::Data::Struct(s) => s,
+        _ => panic!("db_derive only supports structs."),
+    };
+
+    let field_idents = get_named_field_idents(&struct_data);
+
+    // Using a var in an iteration always consumes the var, therefore we must make a `fields_a` and
+    // a `fields_b` in order to perform two loops.
+    //
+    // https://github.com/dtolnay/quote/issues/8
+    let field_idents_a = &field_idents;
+    let field_idents_b = &field_idents;
+
+    let output = quote! {
+        impl db_encode::DBDecode for #name {
+            fn db_decode(bytes: &[u8], i: usize) -> Result<(Self, usize), ssz::DecodeError> {
+                #(
+                    let (#field_idents_a, i) = <_>::ssz_decode(bytes, i)?;
+                )*
+
+                Ok((
+                    Self {
+                        #(
+                            #field_idents_b,
+                        )*
+                    },
+                    i
+                ))
+            }
+        }
+    };
+    output.into()
+}
+
+/*
+/// Returns true if some field has an attribute declaring it should not be deserialized.
+/// +/// The field attribute is: `#[ssz(skip_deserializing)]` +fn should_skip_deserializing(field: &syn::Field) -> bool { + for attr in &field.attrs { + if attr.tts.to_string() == "( skip_deserializing )" { + return true; + } + } + false +} + +/// Implements `ssz::Decodable` for some `struct`. +/// +/// Fields are decoded in the order they are defined. +#[proc_macro_derive(Decode)] +pub fn ssz_decode_derive(input: TokenStream) -> TokenStream { + let item = parse_macro_input!(input as DeriveInput); + + let name = &item.ident; + + let struct_data = match &item.data { + syn::Data::Struct(s) => s, + _ => panic!("ssz_derive only supports structs."), + }; + + let all_idents = get_named_field_idents(&struct_data); + + // Build quotes for fields that should be deserialized and those that should be built from + // `Default`. + let mut quotes = vec![]; + for field in &struct_data.fields { + match &field.ident { + Some(ref ident) => { + if should_skip_deserializing(field) { + quotes.push(quote! { + let #ident = <_>::default(); + }); + } else { + quotes.push(quote! { + let (#ident, i) = <_>::ssz_decode(bytes, i)?; + }); + } + } + _ => panic!("ssz_derive only supports named struct fields."), + }; + } + + let output = quote! { + impl ssz::Decodable for #name { + fn ssz_decode(bytes: &[u8], i: usize) -> Result<(Self, usize), ssz::DecodeError> { + #( + #quotes + )* + + Ok(( + Self { + #( + #all_idents, + )* + }, + i + )) + } + } + }; + output.into() +} + +/// Returns a Vec of `syn::Ident` for each named field in the struct, whilst filtering out fields +/// that should not be tree hashed. +/// +/// # Panics +/// Any unnamed struct field (like in a tuple struct) will raise a panic at compile time. +fn get_tree_hashable_named_field_idents<'a>( + struct_data: &'a syn::DataStruct, +) -> Vec<&'a syn::Ident> { + struct_data + .fields + .iter() + .filter_map(|f| { + if should_skip_tree_hash(&f) { + None + } else { + Some(match &f.ident { + Some(ref ident) => ident, + _ => panic!("ssz_derive only supports named struct fields."), + }) + } + }) + .collect() +} + +/// Returns true if some field has an attribute declaring it should not be tree-hashed. +/// +/// The field attribute is: `#[tree_hash(skip_hashing)]` +fn should_skip_tree_hash(field: &syn::Field) -> bool { + for attr in &field.attrs { + if attr.tts.to_string() == "( skip_hashing )" { + return true; + } + } + false +} + +/// Implements `ssz::TreeHash` for some `struct`. +/// +/// Fields are processed in the order they are defined. +#[proc_macro_derive(TreeHash, attributes(tree_hash))] +pub fn ssz_tree_hash_derive(input: TokenStream) -> TokenStream { + let item = parse_macro_input!(input as DeriveInput); + + let name = &item.ident; + + let struct_data = match &item.data { + syn::Data::Struct(s) => s, + _ => panic!("ssz_derive only supports structs."), + }; + + let field_idents = get_tree_hashable_named_field_idents(&struct_data); + + let output = quote! { + impl ssz::TreeHash for #name { + fn hash_tree_root(&self) -> Vec { + let mut list: Vec> = Vec::new(); + #( + list.push(self.#field_idents.hash_tree_root()); + )* + + ssz::merkle_hash(&mut list) + } + } + }; + output.into() +} + +/// Returns `true` if some `Ident` should be considered to be a signature type. 
+fn type_ident_is_signature(ident: &syn::Ident) -> bool { + match ident.to_string().as_ref() { + "Signature" => true, + "AggregateSignature" => true, + _ => false, + } +} + +/// Takes a `Field` where the type (`ty`) portion is a path (e.g., `types::Signature`) and returns +/// the final `Ident` in that path. +/// +/// E.g., for `types::Signature` returns `Signature`. +fn final_type_ident(field: &syn::Field) -> &syn::Ident { + match &field.ty { + syn::Type::Path(path) => &path.path.segments.last().unwrap().value().ident, + _ => panic!("ssz_derive only supports Path types."), + } +} + +/// Implements `ssz::TreeHash` for some `struct`, whilst excluding any fields following and +/// including a field that is of type "Signature" or "AggregateSignature". +/// +/// See: +/// https://github.com/ethereum/eth2.0-specs/blob/master/specs/simple-serialize.md#signed-roots +/// +/// This is a rather horrendous macro, it will read the type of the object as a string and decide +/// if it's a signature by matching that string against "Signature" or "AggregateSignature". So, +/// it's important that you use those exact words as your type -- don't alias it to something else. +/// +/// If you can think of a better way to do this, please make an issue! +/// +/// Fields are processed in the order they are defined. +#[proc_macro_derive(SignedRoot)] +pub fn ssz_signed_root_derive(input: TokenStream) -> TokenStream { + let item = parse_macro_input!(input as DeriveInput); + + let name = &item.ident; + + let struct_data = match &item.data { + syn::Data::Struct(s) => s, + _ => panic!("ssz_derive only supports structs."), + }; + + let mut field_idents: Vec<&syn::Ident> = vec![]; + + for field in struct_data.fields.iter() { + let final_type_ident = final_type_ident(&field); + + if type_ident_is_signature(final_type_ident) { + break; + } else { + let ident = field + .ident + .as_ref() + .expect("ssz_derive only supports named_struct fields."); + field_idents.push(ident); + } + } + + let output = quote! { + impl ssz::SignedRoot for #name { + fn signed_root(&self) -> Vec { + let mut list: Vec> = Vec::new(); + #( + list.push(self.#field_idents.hash_tree_root()); + )* + + ssz::merkle_hash(&mut list) + } + } + }; + output.into() +} +*/
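
Editor's note (not part of the patch): the two derive macros above expand to plain trait impls. A minimal sketch of roughly what `#[derive(DBEncode, DBDecode)]` generates for the `StorableThing` test struct defined in `db2/src/lib.rs`, assuming the `quote!` templates shown in `db_encode_derive/src/lib.rs`:

impl db_encode::DBEncode for StorableThing {
    fn db_encode(&self, s: &mut ssz::SszStream) {
        // Fields are appended in declaration order.
        s.append(&self.a);
        s.append(&self.b);
    }
}

impl db_encode::DBDecode for StorableThing {
    fn db_decode(bytes: &[u8], i: usize) -> Result<(Self, usize), ssz::DecodeError> {
        // Fields are decoded in declaration order, threading the byte index through.
        let (a, i) = <_>::ssz_decode(bytes, i)?;
        let (b, i) = <_>::ssz_decode(bytes, i)?;
        Ok((Self { a, b }, i))
    }
}

With these impls, plus the `DBRecord` impl selecting `DBColumn::Block`, `StorableThing` can be written and read via `Store::put` / `Store::get` keyed by a `Hash256`, as exercised by the `memorydb_can_store` test.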