From 7a99654f8913d4056ba9ffa8ad7911bd232e2244 Mon Sep 17 00:00:00 2001 From: Paul Hauner Date: Wed, 22 May 2019 17:22:12 +1000 Subject: [PATCH] Add new `CompareFields` trait and derive --- Cargo.toml | 2 + eth2/types/Cargo.toml | 2 + eth2/types/src/beacon_state.rs | 5 +- eth2/utils/compare_fields/Cargo.toml | 10 ++++ eth2/utils/compare_fields/src/lib.rs | 11 ++++ eth2/utils/compare_fields/tests/tests.rs | 46 ++++++++++++++++ eth2/utils/compare_fields_derive/Cargo.toml | 12 +++++ eth2/utils/compare_fields_derive/src/lib.rs | 58 +++++++++++++++++++++ 8 files changed, 144 insertions(+), 2 deletions(-) create mode 100644 eth2/utils/compare_fields/Cargo.toml create mode 100644 eth2/utils/compare_fields/src/lib.rs create mode 100644 eth2/utils/compare_fields/tests/tests.rs create mode 100644 eth2/utils/compare_fields_derive/Cargo.toml create mode 100644 eth2/utils/compare_fields_derive/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 704bfde13..1fb9bd1ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,8 @@ members = [ "eth2/utils/bls", "eth2/utils/boolean-bitfield", "eth2/utils/cached_tree_hash", + "eth2/utils/compare_fields", + "eth2/utils/compare_fields_derive", "eth2/utils/fixed_len_vec", "eth2/utils/hashing", "eth2/utils/honey-badger-split", diff --git a/eth2/types/Cargo.toml b/eth2/types/Cargo.toml index 160697edd..fa1fe6a6d 100644 --- a/eth2/types/Cargo.toml +++ b/eth2/types/Cargo.toml @@ -8,6 +8,8 @@ edition = "2018" bls = { path = "../utils/bls" } boolean-bitfield = { path = "../utils/boolean-bitfield" } cached_tree_hash = { path = "../utils/cached_tree_hash" } +compare_fields = { path = "../utils/compare_fields" } +compare_fields_derive = { path = "../utils/compare_fields_derive" } dirs = "1.0" derivative = "1.0" ethereum-types = "0.5" diff --git a/eth2/types/src/beacon_state.rs b/eth2/types/src/beacon_state.rs index a96cfecb8..a9e6c648e 100644 --- a/eth2/types/src/beacon_state.rs +++ b/eth2/types/src/beacon_state.rs @@ -3,11 +3,11 @@ use self::exit_cache::ExitCache; use crate::test_utils::TestRandom; use crate::*; use cached_tree_hash::{Error as TreeHashCacheError, TreeHashCache}; +use compare_fields_derive::CompareFields; +use fixed_len_vec::{typenum::Unsigned, FixedLenVec}; use hashing::hash; use int_to_bytes::{int_to_bytes32, int_to_bytes8}; use pubkey_cache::PubkeyCache; - -use fixed_len_vec::{typenum::Unsigned, FixedLenVec}; use serde_derive::{Deserialize, Serialize}; use ssz::ssz_encode; use ssz_derive::{Decode, Encode}; @@ -74,6 +74,7 @@ pub enum Error { Decode, TreeHash, CachedTreeHash, + CompareFields, )] pub struct BeaconState where diff --git a/eth2/utils/compare_fields/Cargo.toml b/eth2/utils/compare_fields/Cargo.toml new file mode 100644 index 000000000..33826c71d --- /dev/null +++ b/eth2/utils/compare_fields/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "compare_fields" +version = "0.1.0" +authors = ["Paul Hauner "] +edition = "2018" + +[dev-dependencies] +compare_fields_derive = { path = "../compare_fields_derive" } + +[dependencies] diff --git a/eth2/utils/compare_fields/src/lib.rs b/eth2/utils/compare_fields/src/lib.rs new file mode 100644 index 000000000..ce0832dee --- /dev/null +++ b/eth2/utils/compare_fields/src/lib.rs @@ -0,0 +1,11 @@ +#[derive(Debug, PartialEq, Clone)] +pub struct FieldComparison { + pub equal: bool, + pub field_name: String, + pub a: String, + pub b: String, +} + +pub trait CompareFields { + fn compare_fields(&self, b: &Self) -> Vec; +} diff --git a/eth2/utils/compare_fields/tests/tests.rs b/eth2/utils/compare_fields/tests/tests.rs new file mode 100644 index 000000000..96ea94810 --- /dev/null +++ b/eth2/utils/compare_fields/tests/tests.rs @@ -0,0 +1,46 @@ +use compare_fields::{CompareFields, FieldComparison}; +use compare_fields_derive::CompareFields; + +#[derive(Clone, Debug, CompareFields)] +pub struct Simple { + a: u64, + b: u16, + c: Vec, +} + +#[test] +fn compare() { + let foo = Simple { + a: 42, + b: 12, + c: vec![1, 2], + }; + + let mut bar = foo.clone(); + + let comparisons = foo.compare_fields(&bar); + + assert!(!comparisons.iter().any(|c| c.equal == false)); + + assert_eq!( + comparisons[0], + FieldComparison { + equal: true, + field_name: "a".to_string(), + a: "42".to_string(), + b: "42".to_string(), + } + ); + + bar.a = 30; + + assert_eq!( + foo.compare_fields(&bar)[0], + FieldComparison { + equal: false, + field_name: "a".to_string(), + a: "42".to_string(), + b: "30".to_string(), + } + ); +} diff --git a/eth2/utils/compare_fields_derive/Cargo.toml b/eth2/utils/compare_fields_derive/Cargo.toml new file mode 100644 index 000000000..8832e26d3 --- /dev/null +++ b/eth2/utils/compare_fields_derive/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "compare_fields_derive" +version = "0.1.0" +authors = ["Paul Hauner "] +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +syn = "0.15" +quote = "0.6" diff --git a/eth2/utils/compare_fields_derive/src/lib.rs b/eth2/utils/compare_fields_derive/src/lib.rs new file mode 100644 index 000000000..89c61796c --- /dev/null +++ b/eth2/utils/compare_fields_derive/src/lib.rs @@ -0,0 +1,58 @@ +#![recursion_limit = "256"] +extern crate proc_macro; + +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, DeriveInput}; + +#[proc_macro_derive(CompareFields)] +pub fn compare_fields_derive(input: TokenStream) -> TokenStream { + let item = parse_macro_input!(input as DeriveInput); + + let name = &item.ident; + let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl(); + + let struct_data = match &item.data { + syn::Data::Struct(s) => s, + _ => panic!("compare_fields_derive only supports structs."), + }; + + let mut idents_a = vec![]; + let mut field_names = vec![]; + + for field in struct_data.fields.iter() { + let ident = match &field.ident { + Some(ref ident) => ident, + _ => panic!("compare_fields_derive only supports named struct fields."), + }; + + field_names.push(format!("{:}", ident)); + idents_a.push(ident); + } + + let idents_b = idents_a.clone(); + let idents_c = idents_a.clone(); + let idents_d = idents_a.clone(); + + let output = quote! { + impl #impl_generics compare_fields::CompareFields for #name #ty_generics #where_clause { + fn compare_fields(&self, b: &Self) -> Vec { + let mut comparisons = vec![]; + + #( + comparisons.push( + compare_fields::FieldComparison { + equal: self.#idents_a == b.#idents_b, + field_name: #field_names.to_string(), + a: format!("{:?}", self.#idents_c), + b: format!("{:?}", b.#idents_d), + } + ); + )* + + comparisons + } + } + }; + output.into() +}