common: added Hash unmarshal json length validation

This commit is contained in:
Jeffrey Wilcke 2016-04-01 12:03:06 +02:00
parent 10d3466c93
commit d63e29241d
2 changed files with 35 additions and 0 deletions

View File

@ -19,10 +19,12 @@ package common
import ( import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"math/big" "math/big"
"math/rand" "math/rand"
"reflect" "reflect"
"strings"
) )
const ( const (
@ -30,6 +32,8 @@ const (
AddressLength = 20 AddressLength = 20
) )
var hashJsonLengthErr = errors.New("common: unmarshalJSON failed: hash must be exactly 32 bytes")
type ( type (
Hash [HashLength]byte Hash [HashLength]byte
Address [AddressLength]byte Address [AddressLength]byte
@ -58,6 +62,15 @@ func (h *Hash) UnmarshalJSON(input []byte) error {
if length >= 2 && input[0] == '"' && input[length-1] == '"' { if length >= 2 && input[0] == '"' && input[length-1] == '"' {
input = input[1 : length-1] input = input[1 : length-1]
} }
// strip "0x" for length check
if len(input) > 1 && strings.ToLower(string(input[:2])) == "0x" {
input = input[2:]
}
// validate the length of the input hash
if len(input) != HashLength*2 {
return hashJsonLengthErr
}
h.SetBytes(FromHex(string(input))) h.SetBytes(FromHex(string(input)))
return nil return nil
} }

View File

@ -29,3 +29,25 @@ func TestBytesConversion(t *testing.T) {
t.Errorf("expected %x got %x", exp, hash) t.Errorf("expected %x got %x", exp, hash)
} }
} }
func TestHashJsonValidation(t *testing.T) {
var h Hash
var tests = []struct {
Prefix string
Size int
Error error
}{
{"", 2, hashJsonLengthErr},
{"", 62, hashJsonLengthErr},
{"", 66, hashJsonLengthErr},
{"", 65, hashJsonLengthErr},
{"0X", 64, nil},
{"0x", 64, nil},
{"0x", 62, hashJsonLengthErr},
}
for i, test := range tests {
if err := h.UnmarshalJSON(append([]byte(test.Prefix), make([]byte, test.Size)...)); err != test.Error {
t.Error(i, "expected", test.Error, "got", err)
}
}
}