common: add database/sql support for Hash and Address (#15541)

This commit is contained in:
Vincent Serpoul 2018-07-24 21:15:07 +08:00 committed by Felix Lange
parent d96ba77113
commit 2909f6d7a2
2 changed files with 219 additions and 2 deletions

View File

@ -17,6 +17,7 @@
package common package common
import ( import (
"database/sql/driver"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -31,7 +32,9 @@ import (
// Lengths of hashes and addresses in bytes. // Lengths of hashes and addresses in bytes.
const ( const (
// HashLength is the expected length of the hash
HashLength = 32 HashLength = 32
// AddressLength is the expected length of the adddress
AddressLength = 20 AddressLength = 20
) )
@ -120,6 +123,24 @@ func (h Hash) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(h) return reflect.ValueOf(h)
} }
// Scan implements Scanner for database/sql.
func (h *Hash) Scan(src interface{}) error {
srcB, ok := src.([]byte)
if !ok {
return fmt.Errorf("can't scan %T into Hash", src)
}
if len(srcB) != HashLength {
return fmt.Errorf("can't scan []byte of len %d into Hash, want %d", len(srcB), HashLength)
}
copy(h[:], srcB)
return nil
}
// Value implements valuer for database/sql.
func (h Hash) Value() (driver.Value, error) {
return h[:], nil
}
// UnprefixedHash allows marshaling a Hash without 0x prefix. // UnprefixedHash allows marshaling a Hash without 0x prefix.
type UnprefixedHash Hash type UnprefixedHash Hash
@ -229,6 +250,24 @@ func (a *Address) UnmarshalJSON(input []byte) error {
return hexutil.UnmarshalFixedJSON(addressT, input, a[:]) return hexutil.UnmarshalFixedJSON(addressT, input, a[:])
} }
// Scan implements Scanner for database/sql.
func (a *Address) Scan(src interface{}) error {
srcB, ok := src.([]byte)
if !ok {
return fmt.Errorf("can't scan %T into Address", src)
}
if len(srcB) != AddressLength {
return fmt.Errorf("can't scan []byte of len %d into Address, want %d", len(srcB), AddressLength)
}
copy(a[:], srcB)
return nil
}
// Value implements valuer for database/sql.
func (a Address) Value() (driver.Value, error) {
return a[:], nil
}
// UnprefixedAddress allows marshaling an Address without 0x prefix. // UnprefixedAddress allows marshaling an Address without 0x prefix.
type UnprefixedAddress Address type UnprefixedAddress Address

View File

@ -17,9 +17,10 @@
package common package common
import ( import (
"database/sql/driver"
"encoding/json" "encoding/json"
"math/big" "math/big"
"reflect"
"strings" "strings"
"testing" "testing"
) )
@ -193,3 +194,180 @@ func TestMixedcaseAccount_Address(t *testing.T) {
} }
} }
func TestHash_Scan(t *testing.T) {
type args struct {
src interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "working scan",
args: args{src: []byte{
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
0x10, 0x00,
}},
wantErr: false,
},
{
name: "non working scan",
args: args{src: int64(1234567890)},
wantErr: true,
},
{
name: "invalid length scan",
args: args{src: []byte{
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
}},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &Hash{}
if err := h.Scan(tt.args.src); (err != nil) != tt.wantErr {
t.Errorf("Hash.Scan() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr {
for i := range h {
if h[i] != tt.args.src.([]byte)[i] {
t.Errorf(
"Hash.Scan() didn't scan the %d src correctly (have %X, want %X)",
i, h[i], tt.args.src.([]byte)[i],
)
}
}
}
})
}
}
func TestHash_Value(t *testing.T) {
b := []byte{
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
0x10, 0x00,
}
var usedH Hash
usedH.SetBytes(b)
tests := []struct {
name string
h Hash
want driver.Value
wantErr bool
}{
{
name: "Working value",
h: usedH,
want: b,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.h.Value()
if (err != nil) != tt.wantErr {
t.Errorf("Hash.Value() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Hash.Value() = %v, want %v", got, tt.want)
}
})
}
}
func TestAddress_Scan(t *testing.T) {
type args struct {
src interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "working scan",
args: args{src: []byte{
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
}},
wantErr: false,
},
{
name: "non working scan",
args: args{src: int64(1234567890)},
wantErr: true,
},
{
name: "invalid length scan",
args: args{src: []byte{
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a,
}},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Address{}
if err := a.Scan(tt.args.src); (err != nil) != tt.wantErr {
t.Errorf("Address.Scan() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr {
for i := range a {
if a[i] != tt.args.src.([]byte)[i] {
t.Errorf(
"Address.Scan() didn't scan the %d src correctly (have %X, want %X)",
i, a[i], tt.args.src.([]byte)[i],
)
}
}
}
})
}
}
func TestAddress_Value(t *testing.T) {
b := []byte{
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
}
var usedA Address
usedA.SetBytes(b)
tests := []struct {
name string
a Address
want driver.Value
wantErr bool
}{
{
name: "Working value",
a: usedA,
want: b,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.a.Value()
if (err != nil) != tt.wantErr {
t.Errorf("Address.Value() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Address.Value() = %v, want %v", got, tt.want)
}
})
}
}