From 8deec2e45a315393025145ace43782b6294541d5 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 14 May 2019 15:09:56 +0200 Subject: [PATCH] rlp: fixes for two corner cases and documentation (#19527) These changes fix two corner cases related to internal handling of types in package rlp: The "tail" struct tag can only be applied to the last field. The check for this was wrong and didn't allow for private fields after the field with the tag. Unsupported types (e.g. structs containing int) which implement either the Encoder or Decoder interface but not both couldn't be encoded/decoded. Also fixes #19367 --- rlp/decode.go | 30 ++++++++++---------- rlp/decode_test.go | 32 +++++++++++++++++++++ rlp/encode.go | 28 ++++++++++--------- rlp/encode_test.go | 9 ++++++ rlp/typecache.go | 69 ++++++++++++++++++++++++++-------------------- 5 files changed, 111 insertions(+), 57 deletions(-) diff --git a/rlp/decode.go b/rlp/decode.go index e638c01ea..4f29f2fb0 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -115,15 +115,17 @@ type Decoder interface { // type, Decode will return an error. Decode also supports *big.Int. // There is no size limit for big integers. // +// To decode into a boolean, the input must contain an unsigned integer +// of value zero (false) or one (true). +// // To decode into an interface value, Decode stores one of these // in the value: // // []interface{}, for RLP lists // []byte, for RLP strings // -// Non-empty interface types are not supported, nor are booleans, -// signed integers, floating point numbers, maps, channels and -// functions. +// Non-empty interface types are not supported, nor are signed integers, +// floating point numbers, maps, channels and functions. // // Note that Decode does not set an input limit for all readers // and may be vulnerable to panics cause by huge value sizes. If @@ -306,9 +308,9 @@ func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) { } return decodeByteSlice, nil } - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := cachedTypeInfo1(etype, tags{}) + if etypeinfo.decoderErr != nil { + return nil, etypeinfo.decoderErr } var dec decoder switch { @@ -467,9 +469,9 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { // the pointer's element type. func makePtrDecoder(typ reflect.Type) (decoder, error) { etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := cachedTypeInfo1(etype, tags{}) + if etypeinfo.decoderErr != nil { + return nil, etypeinfo.decoderErr } dec := func(s *Stream, val reflect.Value) (err error) { newval := val @@ -491,9 +493,9 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) { // This decoder is used for pointer-typed struct fields with struct tag "nil". func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := cachedTypeInfo1(etype, tags{}) + if etypeinfo.decoderErr != nil { + return nil, etypeinfo.decoderErr } dec := func(s *Stream, val reflect.Value) (err error) { kind, size, err := s.Kind() @@ -814,12 +816,12 @@ func (s *Stream) Decode(val interface{}) error { if rval.IsNil() { return errDecodeIntoNil } - info, err := cachedTypeInfo(rtyp.Elem(), tags{}) + decoder, err := cachedDecoder(rtyp.Elem()) if err != nil { return err } - err = info.decoder(s, rval.Elem()) + err = decoder(s, rval.Elem()) if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 { // add decode target type to error so context has more meaning decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")")) diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 4d8abd001..fa57182c9 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -347,6 +347,12 @@ type tailUint struct { Tail []uint `rlp:"tail"` } +type tailPrivateFields struct { + A uint + Tail []uint `rlp:"tail"` + x, y bool +} + var ( veryBigInt = big.NewInt(0).Add( big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), @@ -510,6 +516,11 @@ var decodeTests = []decodeTest{ ptr: new(tailRaw), value: tailRaw{A: 1, Tail: []RawValue{}}, }, + { + input: "C3010203", + ptr: new(tailPrivateFields), + value: tailPrivateFields{A: 1, Tail: []uint{2, 3}}, + }, // struct tag "-" { @@ -691,6 +702,27 @@ func TestDecoderInByteSlice(t *testing.T) { } } +type unencodableDecoder func() + +func (f *unencodableDecoder) DecodeRLP(s *Stream) error { + if _, err := s.List(); err != nil { + return err + } + if err := s.ListEnd(); err != nil { + return err + } + *f = func() {} + return nil +} + +func TestDecoderFunc(t *testing.T) { + var x func() + if err := DecodeBytes([]byte{0xC0}, (*unencodableDecoder)(&x)); err != nil { + t.Fatal(err) + } + x() +} + func ExampleDecode() { input, _ := hex.DecodeString("C90A1486666F6F626172") diff --git a/rlp/encode.go b/rlp/encode.go index 445b4b5b2..f255c38a9 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -73,10 +73,12 @@ type Encoder interface { // An unsigned integer value is encoded as an RLP string. Zero always // encodes as an empty RLP string. Encode also supports *big.Int. // +// Boolean values are encoded as unsigned integers zero (false) and one (true). +// // An interface value encodes as the value contained in the interface. // -// Boolean values are not supported, nor are signed integers, floating -// point numbers, maps, channels and functions. +// Signed integers are not supported, nor are floating point numbers, maps, +// channels and functions. func Encode(w io.Writer, val interface{}) error { if outer, ok := w.(*encbuf); ok { // Encode was called by some type's EncodeRLP. @@ -180,11 +182,11 @@ func (w *encbuf) Write(b []byte) (int, error) { func (w *encbuf) encode(val interface{}) error { rval := reflect.ValueOf(val) - ti, err := cachedTypeInfo(rval.Type(), tags{}) + writer, err := cachedWriter(rval.Type()) if err != nil { return err } - return ti.writer(rval, w) + return writer(rval, w) } func (w *encbuf) encodeStringHeader(size int) { @@ -497,17 +499,17 @@ func writeInterface(val reflect.Value, w *encbuf) error { return nil } eval := val.Elem() - ti, err := cachedTypeInfo(eval.Type(), tags{}) + writer, err := cachedWriter(eval.Type()) if err != nil { return err } - return ti.writer(eval, w) + return writer(eval, w) } func makeSliceWriter(typ reflect.Type, ts tags) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err + etypeinfo := cachedTypeInfo1(typ.Elem(), tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr } writer := func(val reflect.Value, w *encbuf) error { if !ts.tail { @@ -543,9 +545,9 @@ func makeStructWriter(typ reflect.Type) (writer, error) { } func makePtrWriter(typ reflect.Type) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err + etypeinfo := cachedTypeInfo1(typ.Elem(), tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr } // determine nil pointer handler @@ -577,7 +579,7 @@ func makePtrWriter(typ reflect.Type) (writer, error) { } return etypeinfo.writer(val.Elem(), w) } - return writer, err + return writer, nil } // putint writes i to the beginning of b in big endian byte diff --git a/rlp/encode_test.go b/rlp/encode_test.go index 827960f7c..6e49b89a8 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -49,6 +49,13 @@ func (e byteEncoder) EncodeRLP(w io.Writer) error { return nil } +type undecodableEncoder func() + +func (f undecodableEncoder) EncodeRLP(w io.Writer) error { + _, err := w.Write(EmptyList) + return err +} + type encodableReader struct { A, B uint } @@ -239,6 +246,8 @@ var encTests = []encTest{ {val: (*testEncoder)(nil), output: "00000000"}, {val: &testEncoder{}, output: "00010001000100010001"}, {val: &testEncoder{errors.New("test error")}, error: "test error"}, + // verify that the Encoder interface works for unsupported types like func(). + {val: undecodableEncoder(func() {}), output: "C0"}, // verify that pointer method testEncoder.EncodeRLP is called for // addressable non-pointer values. {val: &struct{ TE testEncoder }{testEncoder{}}, output: "CA00010001000100010001"}, diff --git a/rlp/typecache.go b/rlp/typecache.go index 8c2dd518e..ab5ee3da7 100644 --- a/rlp/typecache.go +++ b/rlp/typecache.go @@ -29,8 +29,10 @@ var ( ) type typeinfo struct { - decoder - writer + decoder decoder + decoderErr error // error from makeDecoder + writer writer + writerErr error // error from makeWriter } // represents struct tags @@ -56,12 +58,22 @@ type decoder func(*Stream, reflect.Value) error type writer func(reflect.Value, *encbuf) error -func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) { +func cachedDecoder(typ reflect.Type) (decoder, error) { + info := cachedTypeInfo(typ, tags{}) + return info.decoder, info.decoderErr +} + +func cachedWriter(typ reflect.Type) (writer, error) { + info := cachedTypeInfo(typ, tags{}) + return info.writer, info.writerErr +} + +func cachedTypeInfo(typ reflect.Type, tags tags) *typeinfo { typeCacheMutex.RLock() info := typeCache[typekey{typ, tags}] typeCacheMutex.RUnlock() if info != nil { - return info, nil + return info } // not in the cache, need to generate info for this type. typeCacheMutex.Lock() @@ -69,25 +81,20 @@ func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) { return cachedTypeInfo1(typ, tags) } -func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) { +func cachedTypeInfo1(typ reflect.Type, tags tags) *typeinfo { key := typekey{typ, tags} info := typeCache[key] if info != nil { // another goroutine got the write lock first - return info, nil + return info } // put a dummy value into the cache before generating. // if the generator tries to lookup itself, it will get // the dummy value and won't call itself recursively. - typeCache[key] = new(typeinfo) - info, err := genTypeInfo(typ, tags) - if err != nil { - // remove the dummy value if the generator fails - delete(typeCache, key) - return nil, err - } - *typeCache[key] = *info - return typeCache[key], err + info = new(typeinfo) + typeCache[key] = info + info.generate(typ, tags) + return info } type field struct { @@ -96,26 +103,24 @@ type field struct { } func structFields(typ reflect.Type) (fields []field, err error) { + lastPublic := lastPublicField(typ) for i := 0; i < typ.NumField(); i++ { if f := typ.Field(i); f.PkgPath == "" { // exported - tags, err := parseStructTag(typ, i) + tags, err := parseStructTag(typ, i, lastPublic) if err != nil { return nil, err } if tags.ignored { continue } - info, err := cachedTypeInfo1(f.Type, tags) - if err != nil { - return nil, err - } + info := cachedTypeInfo1(f.Type, tags) fields = append(fields, field{i, info}) } } return fields, nil } -func parseStructTag(typ reflect.Type, fi int) (tags, error) { +func parseStructTag(typ reflect.Type, fi, lastPublic int) (tags, error) { f := typ.Field(fi) var ts tags for _, t := range strings.Split(f.Tag.Get("rlp"), ",") { @@ -127,7 +132,7 @@ func parseStructTag(typ reflect.Type, fi int) (tags, error) { ts.nilOK = true case "tail": ts.tail = true - if fi != typ.NumField()-1 { + if fi != lastPublic { return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name) } if f.Type.Kind() != reflect.Slice { @@ -140,15 +145,19 @@ func parseStructTag(typ reflect.Type, fi int) (tags, error) { return ts, nil } -func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) { - info = new(typeinfo) - if info.decoder, err = makeDecoder(typ, tags); err != nil { - return nil, err +func lastPublicField(typ reflect.Type) int { + last := 0 + for i := 0; i < typ.NumField(); i++ { + if typ.Field(i).PkgPath == "" { + last = i + } } - if info.writer, err = makeWriter(typ, tags); err != nil { - return nil, err - } - return info, nil + return last +} + +func (i *typeinfo) generate(typ reflect.Type, tags tags) { + i.decoder, i.decoderErr = makeDecoder(typ, tags) + i.writer, i.writerErr = makeWriter(typ, tags) } func isUint(k reflect.Kind) bool {