// Copyright (c) 2014 Arista Networks, Inc.
// Use of this source code is governed by the Apache License 2.0
// that can be found in the COPYING file.
|
package test
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"math"
|
||
|
"reflect"
|
||
|
|
||
|
"github.com/aristanetworks/goarista/areflect"
|
||
|
"github.com/aristanetworks/goarista/key"
|
||
|
)
|
||
|
|
||
|
// comparableType is the reflect.Type of the key.Comparable interface. It is
// used to detect struct types that supply their own Equal method.
var comparableType = reflect.TypeOf(new(key.Comparable)).Elem()
|
||
|
|
||
|
// DeepEqual is a faster implementation of reflect.DeepEqual that:
|
||
|
// - Has a reflection-free fast-path for all the common types we use.
|
||
|
// - Gives data types the ability to exclude some of their fields from the
|
||
|
// consideration of DeepEqual by tagging them with `deepequal:"ignore"`.
|
||
|
// - Gives data types the ability to define their own comparison method by
|
||
|
// implementing the comparable interface.
|
||
|
// - Supports "composite" (or "complex") keys in maps that are pointers.
|
||
|
func DeepEqual(a, b interface{}) bool {
|
||
|
return deepEqual(a, b, nil)
|
||
|
}
|
||
|
|
||
|
func deepEqual(a, b interface{}, seen map[edge]struct{}) bool {
|
||
|
if a == nil || b == nil {
|
||
|
return a == b
|
||
|
}
|
||
|
switch a := a.(type) {
|
||
|
// Short circuit fast-path for common built-in types.
|
||
|
// Note: the cases are listed by frequency.
|
||
|
case bool:
|
||
|
return a == b
|
||
|
|
||
|
case map[string]interface{}:
|
||
|
v, ok := b.(map[string]interface{})
|
||
|
if !ok || len(a) != len(v) {
|
||
|
return false
|
||
|
}
|
||
|
for key, value := range a {
|
||
|
if other, ok := v[key]; !ok || !deepEqual(value, other, seen) {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
|
||
|
case string, uint32, uint64, int32,
|
||
|
uint16, int16, uint8, int8, int64:
|
||
|
return a == b
|
||
|
|
||
|
case *map[string]interface{}:
|
||
|
v, ok := b.(*map[string]interface{})
|
||
|
if !ok || a == nil || v == nil {
|
||
|
return ok && a == v
|
||
|
}
|
||
|
return deepEqual(*a, *v, seen)
|
||
|
|
||
|
case map[interface{}]interface{}:
|
||
|
v, ok := b.(map[interface{}]interface{})
|
||
|
if !ok {
|
||
|
return false
|
||
|
}
|
||
|
// We compare in both directions to catch keys that are in b but not
|
||
|
// in a. It sucks to have to do another O(N^2) for this, but oh well.
|
||
|
return mapEqual(a, v) && mapEqual(v, a)
|
||
|
|
||
|
case float32:
|
||
|
v, ok := b.(float32)
|
||
|
return ok && (a == b || (math.IsNaN(float64(a)) && math.IsNaN(float64(v))))
|
||
|
case float64:
|
||
|
v, ok := b.(float64)
|
||
|
return ok && (a == b || (math.IsNaN(a) && math.IsNaN(v)))
|
||
|
|
||
|
case []string:
|
||
|
v, ok := b.([]string)
|
||
|
if !ok || len(a) != len(v) {
|
||
|
return false
|
||
|
}
|
||
|
for i, s := range a {
|
||
|
if s != v[i] {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
case []byte:
|
||
|
v, ok := b.([]byte)
|
||
|
return ok && bytes.Equal(a, v)
|
||
|
|
||
|
case map[uint64]interface{}:
|
||
|
v, ok := b.(map[uint64]interface{})
|
||
|
if !ok || len(a) != len(v) {
|
||
|
return false
|
||
|
}
|
||
|
for key, value := range a {
|
||
|
if other, ok := v[key]; !ok || !deepEqual(value, other, seen) {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
|
||
|
case *map[interface{}]interface{}:
|
||
|
v, ok := b.(*map[interface{}]interface{})
|
||
|
if !ok || a == nil || v == nil {
|
||
|
return ok && a == v
|
||
|
}
|
||
|
return deepEqual(*a, *v, seen)
|
||
|
case key.Comparable:
|
||
|
return a.Equal(b)
|
||
|
|
||
|
case []uint32:
|
||
|
v, ok := b.([]uint32)
|
||
|
if !ok || len(a) != len(v) {
|
||
|
return false
|
||
|
}
|
||
|
for i, s := range a {
|
||
|
if s != v[i] {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
case []uint64:
|
||
|
v, ok := b.([]uint64)
|
||
|
if !ok || len(a) != len(v) {
|
||
|
return false
|
||
|
}
|
||
|
for i, s := range a {
|
||
|
if s != v[i] {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
case []interface{}:
|
||
|
v, ok := b.([]interface{})
|
||
|
if !ok || len(a) != len(v) {
|
||
|
return false
|
||
|
}
|
||
|
for i, s := range a {
|
||
|
if !deepEqual(s, v[i], seen) {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
case *[]string:
|
||
|
v, ok := b.(*[]string)
|
||
|
if !ok || a == nil || v == nil {
|
||
|
return ok && a == v
|
||
|
}
|
||
|
return deepEqual(*a, *v, seen)
|
||
|
case *[]interface{}:
|
||
|
v, ok := b.(*[]interface{})
|
||
|
if !ok || a == nil || v == nil {
|
||
|
return ok && a == v
|
||
|
}
|
||
|
return deepEqual(*a, *v, seen)
|
||
|
|
||
|
default:
|
||
|
// Handle other kinds of non-comparable objects.
|
||
|
return genericDeepEqual(a, b, seen)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// edge is an ordered pair of addresses, one from each side of a comparison.
// genericDeepEqual records edges it has started comparing in a set keyed by
// edge, so that revisiting the same pair (a cycle) stops the recursion.
type edge struct {
	from, to uintptr
}
|
||
|
|
||
|
func genericDeepEqual(a, b interface{}, seen map[edge]struct{}) bool {
|
||
|
av := reflect.ValueOf(a)
|
||
|
bv := reflect.ValueOf(b)
|
||
|
if avalid, bvalid := av.IsValid(), bv.IsValid(); !avalid || !bvalid {
|
||
|
return avalid == bvalid
|
||
|
}
|
||
|
if bv.Type() != av.Type() {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
switch av.Kind() {
|
||
|
case reflect.Ptr:
|
||
|
if av.IsNil() || bv.IsNil() {
|
||
|
return a == b
|
||
|
}
|
||
|
|
||
|
av = av.Elem()
|
||
|
bv = bv.Elem()
|
||
|
if av.CanAddr() && bv.CanAddr() {
|
||
|
e := edge{from: av.UnsafeAddr(), to: bv.UnsafeAddr()}
|
||
|
// Detect and prevent cycles.
|
||
|
if seen == nil {
|
||
|
seen = make(map[edge]struct{})
|
||
|
} else if _, ok := seen[e]; ok {
|
||
|
return true
|
||
|
}
|
||
|
seen[e] = struct{}{}
|
||
|
}
|
||
|
|
||
|
return deepEqual(av.Interface(), bv.Interface(), seen)
|
||
|
case reflect.Slice, reflect.Array:
|
||
|
l := av.Len()
|
||
|
if l != bv.Len() {
|
||
|
return false
|
||
|
}
|
||
|
for i := 0; i < l; i++ {
|
||
|
if !deepEqual(av.Index(i).Interface(), bv.Index(i).Interface(), seen) {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
case reflect.Map:
|
||
|
if av.IsNil() != bv.IsNil() {
|
||
|
return false
|
||
|
}
|
||
|
if av.Len() != bv.Len() {
|
||
|
return false
|
||
|
}
|
||
|
if av.Pointer() == bv.Pointer() {
|
||
|
return true
|
||
|
}
|
||
|
for _, k := range av.MapKeys() {
|
||
|
// Upon finding the first key that's a pointer, we bail out and do
|
||
|
// a O(N^2) comparison.
|
||
|
if kk := k.Kind(); kk == reflect.Ptr || kk == reflect.Interface {
|
||
|
ok, _, _ := complexKeyMapEqual(av, bv, seen)
|
||
|
return ok
|
||
|
}
|
||
|
ea := av.MapIndex(k)
|
||
|
eb := bv.MapIndex(k)
|
||
|
if !eb.IsValid() {
|
||
|
return false
|
||
|
}
|
||
|
if !deepEqual(ea.Interface(), eb.Interface(), seen) {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
case reflect.Struct:
|
||
|
typ := av.Type()
|
||
|
if typ.Implements(comparableType) {
|
||
|
return av.Interface().(key.Comparable).Equal(bv.Interface())
|
||
|
}
|
||
|
for i, n := 0, av.NumField(); i < n; i++ {
|
||
|
if typ.Field(i).Tag.Get("deepequal") == "ignore" {
|
||
|
continue
|
||
|
}
|
||
|
af := areflect.ForceExport(av.Field(i))
|
||
|
bf := areflect.ForceExport(bv.Field(i))
|
||
|
if !deepEqual(af.Interface(), bf.Interface(), seen) {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
default:
|
||
|
// Other the basic types.
|
||
|
return a == b
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Compares two maps with complex keys (that are pointers). This assumes the
|
||
|
// maps have already been checked to have the same sizes. The cost of this
|
||
|
// function is O(N^2) in the size of the input maps.
|
||
|
//
|
||
|
// The return is to be interpreted this way:
|
||
|
// true, _, _ => av == bv
|
||
|
// false, key, invalid => the given key wasn't found in bv
|
||
|
// false, key, value => the given key had the given value in bv,
|
||
|
// which is different in av
|
||
|
func complexKeyMapEqual(av, bv reflect.Value,
|
||
|
seen map[edge]struct{}) (bool, reflect.Value, reflect.Value) {
|
||
|
for _, ka := range av.MapKeys() {
|
||
|
var eb reflect.Value // The entry in bv with a key equal to ka
|
||
|
for _, kb := range bv.MapKeys() {
|
||
|
if deepEqual(ka.Elem().Interface(), kb.Elem().Interface(), seen) {
|
||
|
// Found the corresponding entry in bv.
|
||
|
eb = bv.MapIndex(kb)
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
if !eb.IsValid() { // We didn't find a key equal to `ka' in 'bv'.
|
||
|
return false, ka, reflect.Value{}
|
||
|
}
|
||
|
ea := av.MapIndex(ka)
|
||
|
if !deepEqual(ea.Interface(), eb.Interface(), seen) {
|
||
|
return false, ka, eb
|
||
|
}
|
||
|
}
|
||
|
return true, reflect.Value{}, reflect.Value{}
|
||
|
}
|
||
|
|
||
|
// mapEqual does O(N^2) comparisons to check that all the keys present in the
|
||
|
// first map are also present in the second map and have identical values.
|
||
|
func mapEqual(a, b map[interface{}]interface{}) bool {
|
||
|
if len(a) != len(b) {
|
||
|
return false
|
||
|
}
|
||
|
for akey, avalue := range a {
|
||
|
found := false
|
||
|
for bkey, bvalue := range b {
|
||
|
if DeepEqual(akey, bkey) {
|
||
|
if !DeepEqual(avalue, bvalue) {
|
||
|
return false
|
||
|
}
|
||
|
found = true
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
if !found {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
}
|