common: add database/sql support for Hash and Address (#15541)
This commit is contained in:
parent
d96ba77113
commit
2909f6d7a2
@ -17,6 +17,7 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -31,7 +32,9 @@ import (
|
|||||||
|
|
||||||
// Lengths of hashes and addresses in bytes.
|
// Lengths of hashes and addresses in bytes.
|
||||||
const (
|
const (
|
||||||
|
// HashLength is the expected length of the hash
|
||||||
HashLength = 32
|
HashLength = 32
|
||||||
|
// AddressLength is the expected length of the adddress
|
||||||
AddressLength = 20
|
AddressLength = 20
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -120,6 +123,24 @@ func (h Hash) Generate(rand *rand.Rand, size int) reflect.Value {
|
|||||||
return reflect.ValueOf(h)
|
return reflect.ValueOf(h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scan implements Scanner for database/sql.
|
||||||
|
func (h *Hash) Scan(src interface{}) error {
|
||||||
|
srcB, ok := src.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("can't scan %T into Hash", src)
|
||||||
|
}
|
||||||
|
if len(srcB) != HashLength {
|
||||||
|
return fmt.Errorf("can't scan []byte of len %d into Hash, want %d", len(srcB), HashLength)
|
||||||
|
}
|
||||||
|
copy(h[:], srcB)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements valuer for database/sql.
|
||||||
|
func (h Hash) Value() (driver.Value, error) {
|
||||||
|
return h[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
// UnprefixedHash allows marshaling a Hash without 0x prefix.
|
// UnprefixedHash allows marshaling a Hash without 0x prefix.
|
||||||
type UnprefixedHash Hash
|
type UnprefixedHash Hash
|
||||||
|
|
||||||
@ -229,6 +250,24 @@ func (a *Address) UnmarshalJSON(input []byte) error {
|
|||||||
return hexutil.UnmarshalFixedJSON(addressT, input, a[:])
|
return hexutil.UnmarshalFixedJSON(addressT, input, a[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scan implements Scanner for database/sql.
|
||||||
|
func (a *Address) Scan(src interface{}) error {
|
||||||
|
srcB, ok := src.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("can't scan %T into Address", src)
|
||||||
|
}
|
||||||
|
if len(srcB) != AddressLength {
|
||||||
|
return fmt.Errorf("can't scan []byte of len %d into Address, want %d", len(srcB), AddressLength)
|
||||||
|
}
|
||||||
|
copy(a[:], srcB)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements valuer for database/sql.
|
||||||
|
func (a Address) Value() (driver.Value, error) {
|
||||||
|
return a[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
// UnprefixedAddress allows marshaling an Address without 0x prefix.
|
// UnprefixedAddress allows marshaling an Address without 0x prefix.
|
||||||
type UnprefixedAddress Address
|
type UnprefixedAddress Address
|
||||||
|
|
||||||
|
@ -17,9 +17,10 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -193,3 +194,180 @@ func TestMixedcaseAccount_Address(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHash_Scan(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
src interface{}
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "working scan",
|
||||||
|
args: args{src: []byte{
|
||||||
|
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
|
||||||
|
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
|
||||||
|
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
|
||||||
|
0x10, 0x00,
|
||||||
|
}},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non working scan",
|
||||||
|
args: args{src: int64(1234567890)},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid length scan",
|
||||||
|
args: args{src: []byte{
|
||||||
|
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
|
||||||
|
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
|
||||||
|
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
|
||||||
|
}},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := &Hash{}
|
||||||
|
if err := h.Scan(tt.args.src); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Hash.Scan() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantErr {
|
||||||
|
for i := range h {
|
||||||
|
if h[i] != tt.args.src.([]byte)[i] {
|
||||||
|
t.Errorf(
|
||||||
|
"Hash.Scan() didn't scan the %d src correctly (have %X, want %X)",
|
||||||
|
i, h[i], tt.args.src.([]byte)[i],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHash_Value(t *testing.T) {
|
||||||
|
b := []byte{
|
||||||
|
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
|
||||||
|
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
|
||||||
|
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
|
||||||
|
0x10, 0x00,
|
||||||
|
}
|
||||||
|
var usedH Hash
|
||||||
|
usedH.SetBytes(b)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
h Hash
|
||||||
|
want driver.Value
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Working value",
|
||||||
|
h: usedH,
|
||||||
|
want: b,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.h.Value()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Hash.Value() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Hash.Value() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddress_Scan(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
src interface{}
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "working scan",
|
||||||
|
args: args{src: []byte{
|
||||||
|
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
|
||||||
|
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
|
||||||
|
}},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non working scan",
|
||||||
|
args: args{src: int64(1234567890)},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid length scan",
|
||||||
|
args: args{src: []byte{
|
||||||
|
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
|
||||||
|
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a,
|
||||||
|
}},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := &Address{}
|
||||||
|
if err := a.Scan(tt.args.src); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Address.Scan() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantErr {
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != tt.args.src.([]byte)[i] {
|
||||||
|
t.Errorf(
|
||||||
|
"Address.Scan() didn't scan the %d src correctly (have %X, want %X)",
|
||||||
|
i, a[i], tt.args.src.([]byte)[i],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddress_Value(t *testing.T) {
|
||||||
|
b := []byte{
|
||||||
|
0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e,
|
||||||
|
0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15,
|
||||||
|
}
|
||||||
|
var usedA Address
|
||||||
|
usedA.SetBytes(b)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a Address
|
||||||
|
want driver.Value
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Working value",
|
||||||
|
a: usedA,
|
||||||
|
want: b,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.a.Value()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Address.Value() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Address.Value() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user