From 463e62e83327e885ac887f76b8766e9cde79bc6a Mon Sep 17 00:00:00 2001 From: Michael Sproul Date: Wed, 18 Oct 2023 12:59:53 +0000 Subject: [PATCH] Generalise compare_fields to work with iterators (#4823) ## Proposed Changes Add `compare_fields(as_iter)` as a field attribute to `compare_fields_derive`. This allows any iterable type to be compared in the same as a slice (by index). This is forwards-compatible with tree-states types like `List` and `Vector` which can not be cast to slices. --- Cargo.lock | 1 + common/compare_fields/Cargo.toml | 3 ++ common/compare_fields/src/lib.rs | 38 +++++++++++++++++++------ common/compare_fields_derive/src/lib.rs | 13 +++++---- 4 files changed, 41 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b2ab1e28e..083be58c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1120,6 +1120,7 @@ name = "compare_fields" version = "0.2.0" dependencies = [ "compare_fields_derive", + "itertools", ] [[package]] diff --git a/common/compare_fields/Cargo.toml b/common/compare_fields/Cargo.toml index 8df989e72..9972ca75c 100644 --- a/common/compare_fields/Cargo.toml +++ b/common/compare_fields/Cargo.toml @@ -4,6 +4,9 @@ version = "0.2.0" authors = ["Paul Hauner "] edition = { workspace = true } +[dependencies] +itertools = { workspace = true } + [dev-dependencies] compare_fields_derive = { workspace = true } diff --git a/common/compare_fields/src/lib.rs b/common/compare_fields/src/lib.rs index bc2f5446a..27baf1480 100644 --- a/common/compare_fields/src/lib.rs +++ b/common/compare_fields/src/lib.rs @@ -81,11 +81,8 @@ //! } //! ]; //! assert_eq!(bar_a.compare_fields(&bar_b), bar_a_b); -//! -//! -//! -//! // TODO: //! ``` +use itertools::{EitherOrBoth, Itertools}; use std::fmt::Debug; #[derive(Debug, PartialEq, Clone)] @@ -112,13 +109,38 @@ impl Comparison { } pub fn from_slice>(field_name: String, a: &[T], b: &[T]) -> Self { - let mut children = vec![]; + Self::from_iter(field_name, a.iter(), b.iter()) + } - for i in 0..std::cmp::max(a.len(), b.len()) { - children.push(FieldComparison::new(format!("{i}"), &a.get(i), &b.get(i))); + pub fn from_into_iter<'a, T: Debug + PartialEq + 'a>( + field_name: String, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> Self { + Self::from_iter(field_name, a.into_iter(), b.into_iter()) + } + + pub fn from_iter<'a, T: Debug + PartialEq + 'a>( + field_name: String, + a: impl Iterator, + b: impl Iterator, + ) -> Self { + let mut children = vec![]; + let mut all_equal = true; + + for (i, entry) in a.zip_longest(b).enumerate() { + let comparison = match entry { + EitherOrBoth::Both(x, y) => { + FieldComparison::new(format!("{i}"), &Some(x), &Some(y)) + } + EitherOrBoth::Left(x) => FieldComparison::new(format!("{i}"), &Some(x), &None), + EitherOrBoth::Right(y) => FieldComparison::new(format!("{i}"), &None, &Some(y)), + }; + all_equal = all_equal && comparison.equal(); + children.push(comparison); } - Self::parent(field_name, a == b, children) + Self::parent(field_name, all_equal, children) } pub fn retain_children(&mut self, f: F) diff --git a/common/compare_fields_derive/src/lib.rs b/common/compare_fields_derive/src/lib.rs index a8b92b3d5..099db8e79 100644 --- a/common/compare_fields_derive/src/lib.rs +++ b/common/compare_fields_derive/src/lib.rs @@ -4,10 +4,11 @@ use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, DeriveInput}; -fn is_slice(field: &syn::Field) -> bool { +fn is_iter(field: &syn::Field) -> bool { field.attrs.iter().any(|attr| { attr.path.is_ident("compare_fields") - && attr.tokens.to_string().replace(' ', "") == "(as_slice)" + && (attr.tokens.to_string().replace(' ', "") == "(as_slice)" + || attr.tokens.to_string().replace(' ', "") == "(as_iter)") }) } @@ -34,13 +35,13 @@ pub fn compare_fields_derive(input: TokenStream) -> TokenStream { let field_name = ident_a.to_string(); let ident_b = ident_a.clone(); - let quote = if is_slice(field) { + let quote = if is_iter(field) { quote! { - comparisons.push(compare_fields::Comparison::from_slice( + comparisons.push(compare_fields::Comparison::from_into_iter( #field_name.to_string(), &self.#ident_a, - &b.#ident_b) - ); + &b.#ident_b + )); } } else { quote! {