diff --git a/common/types.go b/common/types.go index 4d374ad24..71fe5c95c 100644 --- a/common/types.go +++ b/common/types.go @@ -17,6 +17,7 @@ package common import ( + "database/sql/driver" "encoding/hex" "encoding/json" "fmt" @@ -31,7 +32,9 @@ import ( // Lengths of hashes and addresses in bytes. const ( - HashLength = 32 + // HashLength is the expected length of the hash + HashLength = 32 + // AddressLength is the expected length of the adddress AddressLength = 20 ) @@ -120,6 +123,24 @@ func (h Hash) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.ValueOf(h) } +// Scan implements Scanner for database/sql. +func (h *Hash) Scan(src interface{}) error { + srcB, ok := src.([]byte) + if !ok { + return fmt.Errorf("can't scan %T into Hash", src) + } + if len(srcB) != HashLength { + return fmt.Errorf("can't scan []byte of len %d into Hash, want %d", len(srcB), HashLength) + } + copy(h[:], srcB) + return nil +} + +// Value implements valuer for database/sql. +func (h Hash) Value() (driver.Value, error) { + return h[:], nil +} + // UnprefixedHash allows marshaling a Hash without 0x prefix. type UnprefixedHash Hash @@ -229,6 +250,24 @@ func (a *Address) UnmarshalJSON(input []byte) error { return hexutil.UnmarshalFixedJSON(addressT, input, a[:]) } +// Scan implements Scanner for database/sql. +func (a *Address) Scan(src interface{}) error { + srcB, ok := src.([]byte) + if !ok { + return fmt.Errorf("can't scan %T into Address", src) + } + if len(srcB) != AddressLength { + return fmt.Errorf("can't scan []byte of len %d into Address, want %d", len(srcB), AddressLength) + } + copy(a[:], srcB) + return nil +} + +// Value implements valuer for database/sql. +func (a Address) Value() (driver.Value, error) { + return a[:], nil +} + // UnprefixedAddress allows marshaling an Address without 0x prefix. type UnprefixedAddress Address diff --git a/common/types_test.go b/common/types_test.go index 9e0c5be3a..7095ccd01 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -17,9 +17,10 @@ package common import ( + "database/sql/driver" "encoding/json" - "math/big" + "reflect" "strings" "testing" ) @@ -193,3 +194,180 @@ func TestMixedcaseAccount_Address(t *testing.T) { } } + +func TestHash_Scan(t *testing.T) { + type args struct { + src interface{} + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "working scan", + args: args{src: []byte{ + 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15, + 0x10, 0x00, + }}, + wantErr: false, + }, + { + name: "non working scan", + args: args{src: int64(1234567890)}, + wantErr: true, + }, + { + name: "invalid length scan", + args: args{src: []byte{ + 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15, + }}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Hash{} + if err := h.Scan(tt.args.src); (err != nil) != tt.wantErr { + t.Errorf("Hash.Scan() error = %v, wantErr %v", err, tt.wantErr) + } + + if !tt.wantErr { + for i := range h { + if h[i] != tt.args.src.([]byte)[i] { + t.Errorf( + "Hash.Scan() didn't scan the %d src correctly (have %X, want %X)", + i, h[i], tt.args.src.([]byte)[i], + ) + } + } + } + }) + } +} + +func TestHash_Value(t *testing.T) { + b := []byte{ + 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15, + 0x10, 0x00, + } + var usedH Hash + usedH.SetBytes(b) + tests := []struct { + name string + h Hash + want driver.Value + wantErr bool + }{ + { + name: "Working value", + h: usedH, + want: b, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.h.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Hash.Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Hash.Value() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAddress_Scan(t *testing.T) { + type args struct { + src interface{} + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "working scan", + args: args{src: []byte{ + 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15, + }}, + wantErr: false, + }, + { + name: "non working scan", + args: args{src: int64(1234567890)}, + wantErr: true, + }, + { + name: "invalid length scan", + args: args{src: []byte{ + 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, + }}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Address{} + if err := a.Scan(tt.args.src); (err != nil) != tt.wantErr { + t.Errorf("Address.Scan() error = %v, wantErr %v", err, tt.wantErr) + } + + if !tt.wantErr { + for i := range a { + if a[i] != tt.args.src.([]byte)[i] { + t.Errorf( + "Address.Scan() didn't scan the %d src correctly (have %X, want %X)", + i, a[i], tt.args.src.([]byte)[i], + ) + } + } + } + }) + } +} + +func TestAddress_Value(t *testing.T) { + b := []byte{ + 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15, + } + var usedA Address + usedA.SetBytes(b) + tests := []struct { + name string + a Address + want driver.Value + wantErr bool + }{ + { + name: "Working value", + a: usedA, + want: b, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.a.Value() + if (err != nil) != tt.wantErr { + t.Errorf("Address.Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Address.Value() = %v, want %v", got, tt.want) + } + }) + } +}