rlp: fixes for two corner cases and documentation (#19527)

These changes fix two corner cases related to internal handling of types
in package rlp: The "tail" struct tag can only be applied to the last field.
The check for this was wrong and didn't allow for private fields after the
field with the tag. Unsupported types (e.g. structs containing int) which
implement either the Encoder or Decoder interface but not both 
couldn't be encoded/decoded.

Also fixes #19367
This commit is contained in:
Felix Lange 2019-05-14 15:09:56 +02:00 committed by GitHub
parent 184af72e4e
commit 8deec2e45a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 111 additions and 57 deletions

View File

@ -115,15 +115,17 @@ type Decoder interface {
// type, Decode will return an error. Decode also supports *big.Int. // type, Decode will return an error. Decode also supports *big.Int.
// There is no size limit for big integers. // There is no size limit for big integers.
// //
// To decode into a boolean, the input must contain an unsigned integer
// of value zero (false) or one (true).
//
// To decode into an interface value, Decode stores one of these // To decode into an interface value, Decode stores one of these
// in the value: // in the value:
// //
// []interface{}, for RLP lists // []interface{}, for RLP lists
// []byte, for RLP strings // []byte, for RLP strings
// //
// Non-empty interface types are not supported, nor are booleans, // Non-empty interface types are not supported, nor are signed integers,
// signed integers, floating point numbers, maps, channels and // floating point numbers, maps, channels and functions.
// functions.
// //
// Note that Decode does not set an input limit for all readers // Note that Decode does not set an input limit for all readers
// and may be vulnerable to panics cause by huge value sizes. If // and may be vulnerable to panics cause by huge value sizes. If
@ -306,9 +308,9 @@ func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) {
} }
return decodeByteSlice, nil return decodeByteSlice, nil
} }
etypeinfo, err := cachedTypeInfo1(etype, tags{}) etypeinfo := cachedTypeInfo1(etype, tags{})
if err != nil { if etypeinfo.decoderErr != nil {
return nil, err return nil, etypeinfo.decoderErr
} }
var dec decoder var dec decoder
switch { switch {
@ -467,9 +469,9 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
// the pointer's element type. // the pointer's element type.
func makePtrDecoder(typ reflect.Type) (decoder, error) { func makePtrDecoder(typ reflect.Type) (decoder, error) {
etype := typ.Elem() etype := typ.Elem()
etypeinfo, err := cachedTypeInfo1(etype, tags{}) etypeinfo := cachedTypeInfo1(etype, tags{})
if err != nil { if etypeinfo.decoderErr != nil {
return nil, err return nil, etypeinfo.decoderErr
} }
dec := func(s *Stream, val reflect.Value) (err error) { dec := func(s *Stream, val reflect.Value) (err error) {
newval := val newval := val
@ -491,9 +493,9 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) {
// This decoder is used for pointer-typed struct fields with struct tag "nil". // This decoder is used for pointer-typed struct fields with struct tag "nil".
func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) {
etype := typ.Elem() etype := typ.Elem()
etypeinfo, err := cachedTypeInfo1(etype, tags{}) etypeinfo := cachedTypeInfo1(etype, tags{})
if err != nil { if etypeinfo.decoderErr != nil {
return nil, err return nil, etypeinfo.decoderErr
} }
dec := func(s *Stream, val reflect.Value) (err error) { dec := func(s *Stream, val reflect.Value) (err error) {
kind, size, err := s.Kind() kind, size, err := s.Kind()
@ -814,12 +816,12 @@ func (s *Stream) Decode(val interface{}) error {
if rval.IsNil() { if rval.IsNil() {
return errDecodeIntoNil return errDecodeIntoNil
} }
info, err := cachedTypeInfo(rtyp.Elem(), tags{}) decoder, err := cachedDecoder(rtyp.Elem())
if err != nil { if err != nil {
return err return err
} }
err = info.decoder(s, rval.Elem()) err = decoder(s, rval.Elem())
if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 { if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 {
// add decode target type to error so context has more meaning // add decode target type to error so context has more meaning
decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")")) decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")"))

View File

@ -347,6 +347,12 @@ type tailUint struct {
Tail []uint `rlp:"tail"` Tail []uint `rlp:"tail"`
} }
type tailPrivateFields struct {
A uint
Tail []uint `rlp:"tail"`
x, y bool
}
var ( var (
veryBigInt = big.NewInt(0).Add( veryBigInt = big.NewInt(0).Add(
big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16),
@ -510,6 +516,11 @@ var decodeTests = []decodeTest{
ptr: new(tailRaw), ptr: new(tailRaw),
value: tailRaw{A: 1, Tail: []RawValue{}}, value: tailRaw{A: 1, Tail: []RawValue{}},
}, },
{
input: "C3010203",
ptr: new(tailPrivateFields),
value: tailPrivateFields{A: 1, Tail: []uint{2, 3}},
},
// struct tag "-" // struct tag "-"
{ {
@ -691,6 +702,27 @@ func TestDecoderInByteSlice(t *testing.T) {
} }
} }
type unencodableDecoder func()
func (f *unencodableDecoder) DecodeRLP(s *Stream) error {
if _, err := s.List(); err != nil {
return err
}
if err := s.ListEnd(); err != nil {
return err
}
*f = func() {}
return nil
}
func TestDecoderFunc(t *testing.T) {
var x func()
if err := DecodeBytes([]byte{0xC0}, (*unencodableDecoder)(&x)); err != nil {
t.Fatal(err)
}
x()
}
func ExampleDecode() { func ExampleDecode() {
input, _ := hex.DecodeString("C90A1486666F6F626172") input, _ := hex.DecodeString("C90A1486666F6F626172")

View File

@ -73,10 +73,12 @@ type Encoder interface {
// An unsigned integer value is encoded as an RLP string. Zero always // An unsigned integer value is encoded as an RLP string. Zero always
// encodes as an empty RLP string. Encode also supports *big.Int. // encodes as an empty RLP string. Encode also supports *big.Int.
// //
// Boolean values are encoded as unsigned integers zero (false) and one (true).
//
// An interface value encodes as the value contained in the interface. // An interface value encodes as the value contained in the interface.
// //
// Boolean values are not supported, nor are signed integers, floating // Signed integers are not supported, nor are floating point numbers, maps,
// point numbers, maps, channels and functions. // channels and functions.
func Encode(w io.Writer, val interface{}) error { func Encode(w io.Writer, val interface{}) error {
if outer, ok := w.(*encbuf); ok { if outer, ok := w.(*encbuf); ok {
// Encode was called by some type's EncodeRLP. // Encode was called by some type's EncodeRLP.
@ -180,11 +182,11 @@ func (w *encbuf) Write(b []byte) (int, error) {
func (w *encbuf) encode(val interface{}) error { func (w *encbuf) encode(val interface{}) error {
rval := reflect.ValueOf(val) rval := reflect.ValueOf(val)
ti, err := cachedTypeInfo(rval.Type(), tags{}) writer, err := cachedWriter(rval.Type())
if err != nil { if err != nil {
return err return err
} }
return ti.writer(rval, w) return writer(rval, w)
} }
func (w *encbuf) encodeStringHeader(size int) { func (w *encbuf) encodeStringHeader(size int) {
@ -497,17 +499,17 @@ func writeInterface(val reflect.Value, w *encbuf) error {
return nil return nil
} }
eval := val.Elem() eval := val.Elem()
ti, err := cachedTypeInfo(eval.Type(), tags{}) writer, err := cachedWriter(eval.Type())
if err != nil { if err != nil {
return err return err
} }
return ti.writer(eval, w) return writer(eval, w)
} }
func makeSliceWriter(typ reflect.Type, ts tags) (writer, error) { func makeSliceWriter(typ reflect.Type, ts tags) (writer, error) {
etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) etypeinfo := cachedTypeInfo1(typ.Elem(), tags{})
if err != nil { if etypeinfo.writerErr != nil {
return nil, err return nil, etypeinfo.writerErr
} }
writer := func(val reflect.Value, w *encbuf) error { writer := func(val reflect.Value, w *encbuf) error {
if !ts.tail { if !ts.tail {
@ -543,9 +545,9 @@ func makeStructWriter(typ reflect.Type) (writer, error) {
} }
func makePtrWriter(typ reflect.Type) (writer, error) { func makePtrWriter(typ reflect.Type) (writer, error) {
etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) etypeinfo := cachedTypeInfo1(typ.Elem(), tags{})
if err != nil { if etypeinfo.writerErr != nil {
return nil, err return nil, etypeinfo.writerErr
} }
// determine nil pointer handler // determine nil pointer handler
@ -577,7 +579,7 @@ func makePtrWriter(typ reflect.Type) (writer, error) {
} }
return etypeinfo.writer(val.Elem(), w) return etypeinfo.writer(val.Elem(), w)
} }
return writer, err return writer, nil
} }
// putint writes i to the beginning of b in big endian byte // putint writes i to the beginning of b in big endian byte

View File

@ -49,6 +49,13 @@ func (e byteEncoder) EncodeRLP(w io.Writer) error {
return nil return nil
} }
type undecodableEncoder func()
func (f undecodableEncoder) EncodeRLP(w io.Writer) error {
_, err := w.Write(EmptyList)
return err
}
type encodableReader struct { type encodableReader struct {
A, B uint A, B uint
} }
@ -239,6 +246,8 @@ var encTests = []encTest{
{val: (*testEncoder)(nil), output: "00000000"}, {val: (*testEncoder)(nil), output: "00000000"},
{val: &testEncoder{}, output: "00010001000100010001"}, {val: &testEncoder{}, output: "00010001000100010001"},
{val: &testEncoder{errors.New("test error")}, error: "test error"}, {val: &testEncoder{errors.New("test error")}, error: "test error"},
// verify that the Encoder interface works for unsupported types like func().
{val: undecodableEncoder(func() {}), output: "C0"},
// verify that pointer method testEncoder.EncodeRLP is called for // verify that pointer method testEncoder.EncodeRLP is called for
// addressable non-pointer values. // addressable non-pointer values.
{val: &struct{ TE testEncoder }{testEncoder{}}, output: "CA00010001000100010001"}, {val: &struct{ TE testEncoder }{testEncoder{}}, output: "CA00010001000100010001"},

View File

@ -29,8 +29,10 @@ var (
) )
type typeinfo struct { type typeinfo struct {
decoder decoder decoder
writer decoderErr error // error from makeDecoder
writer writer
writerErr error // error from makeWriter
} }
// represents struct tags // represents struct tags
@ -56,12 +58,22 @@ type decoder func(*Stream, reflect.Value) error
type writer func(reflect.Value, *encbuf) error type writer func(reflect.Value, *encbuf) error
func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) { func cachedDecoder(typ reflect.Type) (decoder, error) {
info := cachedTypeInfo(typ, tags{})
return info.decoder, info.decoderErr
}
func cachedWriter(typ reflect.Type) (writer, error) {
info := cachedTypeInfo(typ, tags{})
return info.writer, info.writerErr
}
func cachedTypeInfo(typ reflect.Type, tags tags) *typeinfo {
typeCacheMutex.RLock() typeCacheMutex.RLock()
info := typeCache[typekey{typ, tags}] info := typeCache[typekey{typ, tags}]
typeCacheMutex.RUnlock() typeCacheMutex.RUnlock()
if info != nil { if info != nil {
return info, nil return info
} }
// not in the cache, need to generate info for this type. // not in the cache, need to generate info for this type.
typeCacheMutex.Lock() typeCacheMutex.Lock()
@ -69,25 +81,20 @@ func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) {
return cachedTypeInfo1(typ, tags) return cachedTypeInfo1(typ, tags)
} }
func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) { func cachedTypeInfo1(typ reflect.Type, tags tags) *typeinfo {
key := typekey{typ, tags} key := typekey{typ, tags}
info := typeCache[key] info := typeCache[key]
if info != nil { if info != nil {
// another goroutine got the write lock first // another goroutine got the write lock first
return info, nil return info
} }
// put a dummy value into the cache before generating. // put a dummy value into the cache before generating.
// if the generator tries to lookup itself, it will get // if the generator tries to lookup itself, it will get
// the dummy value and won't call itself recursively. // the dummy value and won't call itself recursively.
typeCache[key] = new(typeinfo) info = new(typeinfo)
info, err := genTypeInfo(typ, tags) typeCache[key] = info
if err != nil { info.generate(typ, tags)
// remove the dummy value if the generator fails return info
delete(typeCache, key)
return nil, err
}
*typeCache[key] = *info
return typeCache[key], err
} }
type field struct { type field struct {
@ -96,26 +103,24 @@ type field struct {
} }
func structFields(typ reflect.Type) (fields []field, err error) { func structFields(typ reflect.Type) (fields []field, err error) {
lastPublic := lastPublicField(typ)
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
if f := typ.Field(i); f.PkgPath == "" { // exported if f := typ.Field(i); f.PkgPath == "" { // exported
tags, err := parseStructTag(typ, i) tags, err := parseStructTag(typ, i, lastPublic)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if tags.ignored { if tags.ignored {
continue continue
} }
info, err := cachedTypeInfo1(f.Type, tags) info := cachedTypeInfo1(f.Type, tags)
if err != nil {
return nil, err
}
fields = append(fields, field{i, info}) fields = append(fields, field{i, info})
} }
} }
return fields, nil return fields, nil
} }
func parseStructTag(typ reflect.Type, fi int) (tags, error) { func parseStructTag(typ reflect.Type, fi, lastPublic int) (tags, error) {
f := typ.Field(fi) f := typ.Field(fi)
var ts tags var ts tags
for _, t := range strings.Split(f.Tag.Get("rlp"), ",") { for _, t := range strings.Split(f.Tag.Get("rlp"), ",") {
@ -127,7 +132,7 @@ func parseStructTag(typ reflect.Type, fi int) (tags, error) {
ts.nilOK = true ts.nilOK = true
case "tail": case "tail":
ts.tail = true ts.tail = true
if fi != typ.NumField()-1 { if fi != lastPublic {
return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name) return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name)
} }
if f.Type.Kind() != reflect.Slice { if f.Type.Kind() != reflect.Slice {
@ -140,15 +145,19 @@ func parseStructTag(typ reflect.Type, fi int) (tags, error) {
return ts, nil return ts, nil
} }
func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) { func lastPublicField(typ reflect.Type) int {
info = new(typeinfo) last := 0
if info.decoder, err = makeDecoder(typ, tags); err != nil { for i := 0; i < typ.NumField(); i++ {
return nil, err if typ.Field(i).PkgPath == "" {
last = i
}
} }
if info.writer, err = makeWriter(typ, tags); err != nil { return last
return nil, err }
}
return info, nil func (i *typeinfo) generate(typ reflect.Type, tags tags) {
i.decoder, i.decoderErr = makeDecoder(typ, tags)
i.writer, i.writerErr = makeWriter(typ, tags)
} }
func isUint(k reflect.Kind) bool { func isUint(k reflect.Kind) bool {