rlp: stricter rules for structs and pointers

The rules have changed as follows:

* When decoding into pointers, empty values no longer produce
  a nil pointer. This can be overriden for struct fields using the
  struct tag "nil".
* When decoding into structs, the input list must contain an element
  for each field.
This commit is contained in:
Felix Lange 2015-04-17 01:16:46 +02:00
parent 1e2c93aa2d
commit cad64fb911
4 changed files with 148 additions and 52 deletions

View File

@ -36,17 +36,26 @@ type Decoder interface {
// If the type implements the Decoder interface, decode calls // If the type implements the Decoder interface, decode calls
// DecodeRLP. // DecodeRLP.
// //
// To decode into a pointer, Decode will set the pointer to nil if the // To decode into a pointer, Decode will decode into the value pointed
// input has size zero. If the input has nonzero size, Decode will // to. If the pointer is nil, a new value of the pointer's element
// parse the input data into a value of the type being pointed to. // type is allocated. If the pointer is non-nil, the existing value
// If the pointer is non-nil, the existing value will reused. // will reused.
// //
// To decode into a struct, Decode expects the input to be an RLP // To decode into a struct, Decode expects the input to be an RLP
// list. The decoded elements of the list are assigned to each public // list. The decoded elements of the list are assigned to each public
// field in the order given by the struct's definition. If the input // field in the order given by the struct's definition. The input list
// list has too few elements, no error is returned and the remaining // must contain an element for each decoded field. Decode returns an
// fields will have the zero value. // error if there are too few or too many elements.
// Recursive struct types are supported. //
// The decoding of struct fields honours one particular struct tag,
// "nil". This tag applies to pointer-typed fields and changes the
// decoding rules for the field such that input values of size zero
// decode as a nil pointer. This tag can be useful when decoding recursive
// types.
//
// type StructWithEmptyOK struct {
// Foo *[20]byte `rlp:"nil"`
// }
// //
// To decode into a slice, the input must be a list and the resulting // To decode into a slice, the input must be a list and the resulting
// slice will contain the input elements in order. // slice will contain the input elements in order.
@ -54,7 +63,7 @@ type Decoder interface {
// can also be an RLP string. // can also be an RLP string.
// //
// To decode into a Go string, the input must be an RLP string. The // To decode into a Go string, the input must be an RLP string. The
// bytes are taken as-is and will not necessarily be valid UTF-8. // input bytes are taken as-is and will not necessarily be valid UTF-8.
// //
// To decode into an unsigned integer type, the input must also be an RLP // To decode into an unsigned integer type, the input must also be an RLP
// string. The bytes are interpreted as a big endian representation of // string. The bytes are interpreted as a big endian representation of
@ -65,8 +74,8 @@ type Decoder interface {
// To decode into an interface value, Decode stores one of these // To decode into an interface value, Decode stores one of these
// in the value: // in the value:
// //
// []interface{}, for RLP lists // []interface{}, for RLP lists
// []byte, for RLP strings // []byte, for RLP strings
// //
// Non-empty interface types are not supported, nor are booleans, // Non-empty interface types are not supported, nor are booleans,
// signed integers, floating point numbers, maps, channels and // signed integers, floating point numbers, maps, channels and
@ -136,7 +145,7 @@ var (
bigInt = reflect.TypeOf(big.Int{}) bigInt = reflect.TypeOf(big.Int{})
) )
func makeDecoder(typ reflect.Type) (dec decoder, err error) { func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) {
kind := typ.Kind() kind := typ.Kind()
switch { switch {
case typ.Implements(decoderInterface): case typ.Implements(decoderInterface):
@ -156,6 +165,9 @@ func makeDecoder(typ reflect.Type) (dec decoder, err error) {
case kind == reflect.Struct: case kind == reflect.Struct:
return makeStructDecoder(typ) return makeStructDecoder(typ)
case kind == reflect.Ptr: case kind == reflect.Ptr:
if tags.nilOK {
return makeOptionalPtrDecoder(typ)
}
return makePtrDecoder(typ) return makePtrDecoder(typ)
case kind == reflect.Interface: case kind == reflect.Interface:
return decodeInterface, nil return decodeInterface, nil
@ -214,7 +226,7 @@ func makeListDecoder(typ reflect.Type) (decoder, error) {
return decodeByteSlice, nil return decodeByteSlice, nil
} }
} }
etypeinfo, err := cachedTypeInfo1(etype) etypeinfo, err := cachedTypeInfo1(etype, tags{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -352,11 +364,6 @@ func zero(val reflect.Value, start int) {
} }
} }
type field struct {
index int
info *typeinfo
}
func makeStructDecoder(typ reflect.Type) (decoder, error) { func makeStructDecoder(typ reflect.Type) (decoder, error) {
fields, err := structFields(typ) fields, err := structFields(typ)
if err != nil { if err != nil {
@ -369,8 +376,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
for _, f := range fields { for _, f := range fields {
err = f.info.decoder(s, val.Field(f.index)) err = f.info.decoder(s, val.Field(f.index))
if err == EOL { if err == EOL {
// too few elements. leave the rest at their zero value. return &decodeError{msg: "too few elements", typ: typ}
break
} else if err != nil { } else if err != nil {
return addErrorContext(err, "."+typ.Field(f.index).Name) return addErrorContext(err, "."+typ.Field(f.index).Name)
} }
@ -380,9 +386,35 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
return dec, nil return dec, nil
} }
// makePtrDecoder creates a decoder that decodes into
// the pointer's element type.
func makePtrDecoder(typ reflect.Type) (decoder, error) { func makePtrDecoder(typ reflect.Type) (decoder, error) {
etype := typ.Elem() etype := typ.Elem()
etypeinfo, err := cachedTypeInfo1(etype) etypeinfo, err := cachedTypeInfo1(etype, tags{})
if err != nil {
return nil, err
}
dec := func(s *Stream, val reflect.Value) (err error) {
newval := val
if val.IsNil() {
newval = reflect.New(etype)
}
if err = etypeinfo.decoder(s, newval.Elem()); err == nil {
val.Set(newval)
}
return err
}
return dec, nil
}
// makeOptionalPtrDecoder creates a decoder that decodes empty values
// as nil. Non-empty values are decoded into a value of the element type,
// just like makePtrDecoder does.
//
// This decoder is used for pointer-typed struct fields with struct tag "nil".
func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) {
etype := typ.Elem()
etypeinfo, err := cachedTypeInfo1(etype, tags{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -706,7 +738,7 @@ func (s *Stream) Decode(val interface{}) error {
if rval.IsNil() { if rval.IsNil() {
return errDecodeIntoNil return errDecodeIntoNil
} }
info, err := cachedTypeInfo(rtyp.Elem()) info, err := cachedTypeInfo(rtyp.Elem(), tags{})
if err != nil { if err != nil {
return err return err
} }

View File

@ -280,7 +280,7 @@ type simplestruct struct {
type recstruct struct { type recstruct struct {
I uint I uint
Child *recstruct Child *recstruct `rlp:"nil"`
} }
var ( var (
@ -390,15 +390,33 @@ var decodeTests = []decodeTest{
{input: "8105", ptr: new(big.Int), error: "rlp: non-canonical size information for *big.Int"}, {input: "8105", ptr: new(big.Int), error: "rlp: non-canonical size information for *big.Int"},
// structs // structs
{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}},
{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}},
{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}},
{ {
input: "C501C302C103", input: "C50583343434",
ptr: new(simplestruct),
value: simplestruct{5, "444"},
},
{
input: "C601C402C203C0",
ptr: new(recstruct), ptr: new(recstruct),
value: recstruct{1, &recstruct{2, &recstruct{3, nil}}}, value: recstruct{1, &recstruct{2, &recstruct{3, nil}}},
}, },
// struct errors
{
input: "C0",
ptr: new(simplestruct),
error: "rlp: too few elements for rlp.simplestruct",
},
{
input: "C105",
ptr: new(simplestruct),
error: "rlp: too few elements for rlp.simplestruct",
},
{
input: "C7C50583343434C0",
ptr: new([]*simplestruct),
error: "rlp: too few elements for rlp.simplestruct, decoding into ([]*rlp.simplestruct)[1]",
},
{ {
input: "83222222", input: "83222222",
ptr: new(simplestruct), ptr: new(simplestruct),
@ -417,19 +435,15 @@ var decodeTests = []decodeTest{
// pointers // pointers
{input: "00", ptr: new(*[]byte), value: &[]byte{0}}, {input: "00", ptr: new(*[]byte), value: &[]byte{0}},
{input: "80", ptr: new(*uint), value: (*uint)(nil)}, {input: "80", ptr: new(*uint), value: uintp(0)},
{input: "C0", ptr: new(*uint), value: (*uint)(nil)}, {input: "C0", ptr: new(*uint), error: "rlp: expected input string or byte for uint"},
{input: "07", ptr: new(*uint), value: uintp(7)}, {input: "07", ptr: new(*uint), value: uintp(7)},
{input: "8158", ptr: new(*uint), value: uintp(0x58)}, {input: "8158", ptr: new(*uint), value: uintp(0x58)},
{input: "C109", ptr: new(*[]uint), value: &[]uint{9}}, {input: "C109", ptr: new(*[]uint), value: &[]uint{9}},
{input: "C58403030303", ptr: new(*[][]byte), value: &[][]byte{{3, 3, 3, 3}}}, {input: "C58403030303", ptr: new(*[][]byte), value: &[][]byte{{3, 3, 3, 3}}},
// check that input position is advanced also for empty values. // check that input position is advanced also for empty values.
{input: "C3808005", ptr: new([]*uint), value: []*uint{nil, nil, uintp(5)}}, {input: "C3808005", ptr: new([]*uint), value: []*uint{uintp(0), uintp(0), uintp(5)}},
// pointer should be reset to nil
{input: "05", ptr: sharedPtr, value: uintp(5)},
{input: "80", ptr: sharedPtr, value: (*uint)(nil)},
// interface{} // interface{}
{input: "00", ptr: new(interface{}), value: []byte{0}}, {input: "00", ptr: new(interface{}), value: []byte{0}},
@ -599,6 +613,33 @@ func ExampleDecode() {
// Decoded value: rlp.example{A:0xa, B:0x14, private:0x0, String:"foobar"} // Decoded value: rlp.example{A:0xa, B:0x14, private:0x0, String:"foobar"}
} }
func ExampleDecode_structTagNil() {
// In this example, we'll use the "nil" struct tag to change
// how a pointer-typed field is decoded. The input contains an RLP
// list of one element, an empty string.
input := []byte{0xC1, 0x80}
// This type uses the normal rules.
// The empty input string is decoded as a pointer to an empty Go string.
var normalRules struct {
String *string
}
Decode(bytes.NewReader(input), &normalRules)
fmt.Printf("normal: String = %q\n", *normalRules.String)
// This type uses the struct tag.
// The empty input string is decoded as a nil pointer.
var withEmptyOK struct {
String *string `rlp:"nil"`
}
Decode(bytes.NewReader(input), &withEmptyOK)
fmt.Printf("with nil tag: String = %v\n", withEmptyOK.String)
// Output:
// normal: String = ""
// with nil tag: String = <nil>
}
func ExampleStream() { func ExampleStream() {
input, _ := hex.DecodeString("C90A1486666F6F626172") input, _ := hex.DecodeString("C90A1486666F6F626172")
s := NewStream(bytes.NewReader(input), 0) s := NewStream(bytes.NewReader(input), 0)

View File

@ -194,7 +194,7 @@ func (w *encbuf) Write(b []byte) (int, error) {
func (w *encbuf) encode(val interface{}) error { func (w *encbuf) encode(val interface{}) error {
rval := reflect.ValueOf(val) rval := reflect.ValueOf(val)
ti, err := cachedTypeInfo(rval.Type()) ti, err := cachedTypeInfo(rval.Type(), tags{})
if err != nil { if err != nil {
return err return err
} }
@ -485,7 +485,7 @@ func writeInterface(val reflect.Value, w *encbuf) error {
return nil return nil
} }
eval := val.Elem() eval := val.Elem()
ti, err := cachedTypeInfo(eval.Type()) ti, err := cachedTypeInfo(eval.Type(), tags{})
if err != nil { if err != nil {
return err return err
} }
@ -493,7 +493,7 @@ func writeInterface(val reflect.Value, w *encbuf) error {
} }
func makeSliceWriter(typ reflect.Type) (writer, error) { func makeSliceWriter(typ reflect.Type) (writer, error) {
etypeinfo, err := cachedTypeInfo1(typ.Elem()) etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -530,7 +530,7 @@ func makeStructWriter(typ reflect.Type) (writer, error) {
} }
func makePtrWriter(typ reflect.Type) (writer, error) { func makePtrWriter(typ reflect.Type) (writer, error) {
etypeinfo, err := cachedTypeInfo1(typ.Elem()) etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -7,7 +7,7 @@ import (
var ( var (
typeCacheMutex sync.RWMutex typeCacheMutex sync.RWMutex
typeCache = make(map[reflect.Type]*typeinfo) typeCache = make(map[typekey]*typeinfo)
) )
type typeinfo struct { type typeinfo struct {
@ -15,13 +15,25 @@ type typeinfo struct {
writer writer
} }
// represents struct tags
type tags struct {
nilOK bool
}
type typekey struct {
reflect.Type
// the key must include the struct tags because they
// might generate a different decoder.
tags
}
type decoder func(*Stream, reflect.Value) error type decoder func(*Stream, reflect.Value) error
type writer func(reflect.Value, *encbuf) error type writer func(reflect.Value, *encbuf) error
func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) { func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) {
typeCacheMutex.RLock() typeCacheMutex.RLock()
info := typeCache[typ] info := typeCache[typekey{typ, tags}]
typeCacheMutex.RUnlock() typeCacheMutex.RUnlock()
if info != nil { if info != nil {
return info, nil return info, nil
@ -29,11 +41,12 @@ func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) {
// not in the cache, need to generate info for this type. // not in the cache, need to generate info for this type.
typeCacheMutex.Lock() typeCacheMutex.Lock()
defer typeCacheMutex.Unlock() defer typeCacheMutex.Unlock()
return cachedTypeInfo1(typ) return cachedTypeInfo1(typ, tags)
} }
func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) { func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) {
info := typeCache[typ] key := typekey{typ, tags}
info := typeCache[key]
if info != nil { if info != nil {
// another goroutine got the write lock first // another goroutine got the write lock first
return info, nil return info, nil
@ -41,21 +54,27 @@ func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) {
// put a dummmy value into the cache before generating. // put a dummmy value into the cache before generating.
// if the generator tries to lookup itself, it will get // if the generator tries to lookup itself, it will get
// the dummy value and won't call itself recursively. // the dummy value and won't call itself recursively.
typeCache[typ] = new(typeinfo) typeCache[key] = new(typeinfo)
info, err := genTypeInfo(typ) info, err := genTypeInfo(typ, tags)
if err != nil { if err != nil {
// remove the dummy value if the generator fails // remove the dummy value if the generator fails
delete(typeCache, typ) delete(typeCache, key)
return nil, err return nil, err
} }
*typeCache[typ] = *info *typeCache[key] = *info
return typeCache[typ], err return typeCache[key], err
}
type field struct {
index int
info *typeinfo
} }
func structFields(typ reflect.Type) (fields []field, err error) { func structFields(typ reflect.Type) (fields []field, err error) {
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
if f := typ.Field(i); f.PkgPath == "" { // exported if f := typ.Field(i); f.PkgPath == "" { // exported
info, err := cachedTypeInfo1(f.Type) tags := parseStructTag(f.Tag.Get("rlp"))
info, err := cachedTypeInfo1(f.Type, tags)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -65,9 +84,13 @@ func structFields(typ reflect.Type) (fields []field, err error) {
return fields, nil return fields, nil
} }
func genTypeInfo(typ reflect.Type) (info *typeinfo, err error) { func parseStructTag(tag string) tags {
return tags{nilOK: tag == "nil"}
}
func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) {
info = new(typeinfo) info = new(typeinfo)
if info.decoder, err = makeDecoder(typ); err != nil { if info.decoder, err = makeDecoder(typ, tags); err != nil {
return nil, err return nil, err
} }
if info.writer, err = makeWriter(typ); err != nil { if info.writer, err = makeWriter(typ); err != nil {