rlp: improve decoder stream implementation (#22858)

This commit makes various cleanup changes to rlp.Stream.

* rlp: shrink Stream struct

This removes a lot of unused padding space in Stream by reordering the
fields. The size of Stream changes from 120 bytes to 88 bytes. Stream
instances are internally cached and reused using sync.Pool, so this does
not improve performance.

* rlp: simplify list stack

The list stack kept track of the size of the current list context as
well as the current offset into it. The size had to be stored in the
stack in order to subtract it from the remaining bytes of any enclosing
list in ListEnd. It seems that this can be implemented in a simpler
way: just subtract the size from the enclosing list context in List instead.
This commit is contained in:
Felix Lange 2021-05-18 12:10:27 +02:00 committed by GitHub
parent 3e6f46caec
commit 088da24ebf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -518,7 +518,7 @@ func decodeDecoder(s *Stream, val reflect.Value) error {
} }
// Kind represents the kind of value contained in an RLP stream. // Kind represents the kind of value contained in an RLP stream.
type Kind int type Kind int8
const ( const (
Byte Kind = iota Byte Kind = iota
@ -561,22 +561,16 @@ type ByteReader interface {
type Stream struct { type Stream struct {
r ByteReader r ByteReader
// number of bytes remaining to be read from r. remaining uint64 // number of bytes remaining to be read from r
remaining uint64
limited bool
// auxiliary buffer for integer decoding
uintbuf []byte
kind Kind // kind of value ahead
size uint64 // size of value ahead size uint64 // size of value ahead
byteval byte // value of single byte in type tag
kinderr error // error from last readKind kinderr error // error from last readKind
stack []listpos stack []uint64 // list sizes
uintbuf [8]byte // auxiliary buffer for integer decoding
kind Kind // kind of value ahead
byteval byte // value of single byte in type tag
limited bool // true if input limit is in effect
} }
type listpos struct{ pos, size uint64 }
// NewStream creates a new decoding stream reading from r. // NewStream creates a new decoding stream reading from r.
// //
// If r implements the ByteReader interface, Stream will // If r implements the ByteReader interface, Stream will
@ -646,8 +640,8 @@ func (s *Stream) Raw() ([]byte, error) {
s.kind = -1 // rearm Kind s.kind = -1 // rearm Kind
return []byte{s.byteval}, nil return []byte{s.byteval}, nil
} }
// the original header has already been read and is no longer // The original header has already been read and is no longer
// available. read content and put a new header in front of it. // available. Read content and put a new header in front of it.
start := headsize(size) start := headsize(size)
buf := make([]byte, uint64(start)+size) buf := make([]byte, uint64(start)+size)
if err := s.readFull(buf[start:]); err != nil { if err := s.readFull(buf[start:]); err != nil {
@ -730,7 +724,14 @@ func (s *Stream) List() (size uint64, err error) {
if kind != List { if kind != List {
return 0, ErrExpectedList return 0, ErrExpectedList
} }
s.stack = append(s.stack, listpos{0, size})
// Remove size of inner list from outer list before pushing the new size
// onto the stack. This ensures that the remaining outer list size will
// be correct after the matching call to ListEnd.
if inList, limit := s.listLimit(); inList {
s.stack[len(s.stack)-1] = limit - size
}
s.stack = append(s.stack, size)
s.kind = -1 s.kind = -1
s.size = 0 s.size = 0
return size, nil return size, nil
@ -739,17 +740,13 @@ func (s *Stream) List() (size uint64, err error) {
// ListEnd returns to the enclosing list. // ListEnd returns to the enclosing list.
// The input reader must be positioned at the end of a list. // The input reader must be positioned at the end of a list.
func (s *Stream) ListEnd() error { func (s *Stream) ListEnd() error {
if len(s.stack) == 0 { // Ensure that no more data is remaining in the current list.
if inList, listLimit := s.listLimit(); !inList {
return errNotInList return errNotInList
} } else if listLimit > 0 {
tos := s.stack[len(s.stack)-1]
if tos.pos != tos.size {
return errNotAtEOL return errNotAtEOL
} }
s.stack = s.stack[:len(s.stack)-1] // pop s.stack = s.stack[:len(s.stack)-1] // pop
if len(s.stack) > 0 {
s.stack[len(s.stack)-1].pos += tos.size
}
s.kind = -1 s.kind = -1
s.size = 0 s.size = 0
return nil return nil
@ -777,7 +774,7 @@ func (s *Stream) Decode(val interface{}) error {
err = decoder(s, rval.Elem()) err = decoder(s, rval.Elem())
if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 { if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 {
// add decode target type to error so context has more meaning // Add decode target type to error so context has more meaning.
decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")")) decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")"))
} }
return err return err
@ -800,6 +797,9 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
case *bytes.Reader: case *bytes.Reader:
s.remaining = uint64(br.Len()) s.remaining = uint64(br.Len())
s.limited = true s.limited = true
case *bytes.Buffer:
s.remaining = uint64(br.Len())
s.limited = true
case *strings.Reader: case *strings.Reader:
s.remaining = uint64(br.Len()) s.remaining = uint64(br.Len())
s.limited = true s.limited = true
@ -818,10 +818,8 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
s.size = 0 s.size = 0
s.kind = -1 s.kind = -1
s.kinderr = nil s.kinderr = nil
if s.uintbuf == nil {
s.uintbuf = make([]byte, 8)
}
s.byteval = 0 s.byteval = 0
s.uintbuf = [8]byte{}
} }
// Kind returns the kind and size of the next value in the // Kind returns the kind and size of the next value in the
@ -836,35 +834,29 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
// the value. Subsequent calls to Kind (until the value is decoded) // the value. Subsequent calls to Kind (until the value is decoded)
// will not advance the input reader and return cached information. // will not advance the input reader and return cached information.
func (s *Stream) Kind() (kind Kind, size uint64, err error) { func (s *Stream) Kind() (kind Kind, size uint64, err error) {
var tos *listpos if s.kind >= 0 {
if len(s.stack) > 0 { return s.kind, s.size, s.kinderr
tos = &s.stack[len(s.stack)-1]
} }
if s.kind < 0 {
s.kinderr = nil // Check for end of list. This needs to be done here because readKind
// Don't read further if we're at the end of the // checks against the list size, and would return the wrong error.
// innermost list. inList, listLimit := s.listLimit()
if tos != nil && tos.pos == tos.size { if inList && listLimit == 0 {
return 0, 0, EOL return 0, 0, EOL
} }
// Read the actual size tag.
s.kind, s.size, s.kinderr = s.readKind() s.kind, s.size, s.kinderr = s.readKind()
if s.kinderr == nil { if s.kinderr == nil {
if tos == nil { // Check the data size of the value ahead against input limits. This
// At toplevel, check that the value is smaller // is done here because many decoders require allocating an input
// than the remaining input length. // buffer matching the value size. Checking it here protects those
if s.limited && s.size > s.remaining { // decoders from inputs declaring very large value size.
if inList && s.size > listLimit {
s.kinderr = ErrElemTooLarge
} else if s.limited && s.size > s.remaining {
s.kinderr = ErrValueTooLarge s.kinderr = ErrValueTooLarge
} }
} else {
// Inside a list, check that the value doesn't overflow the list.
if s.size > tos.size-tos.pos {
s.kinderr = ErrElemTooLarge
} }
}
}
}
// Note: this might return a sticky error generated
// by an earlier call to readKind.
return s.kind, s.size, s.kinderr return s.kind, s.size, s.kinderr
} }
@ -891,37 +883,35 @@ func (s *Stream) readKind() (kind Kind, size uint64, err error) {
s.byteval = b s.byteval = b
return Byte, 0, nil return Byte, 0, nil
case b < 0xB8: case b < 0xB8:
// Otherwise, if a string is 0-55 bytes long, // Otherwise, if a string is 0-55 bytes long, the RLP encoding consists
// the RLP encoding consists of a single byte with value 0x80 plus the // of a single byte with value 0x80 plus the length of the string
// length of the string followed by the string. The range of the first // followed by the string. The range of the first byte is thus [0x80, 0xB7].
// byte is thus [0x80, 0xB7].
return String, uint64(b - 0x80), nil return String, uint64(b - 0x80), nil
case b < 0xC0: case b < 0xC0:
// If a string is more than 55 bytes long, the // If a string is more than 55 bytes long, the RLP encoding consists of a
// RLP encoding consists of a single byte with value 0xB7 plus the length // single byte with value 0xB7 plus the length of the length of the
// of the length of the string in binary form, followed by the length of // string in binary form, followed by the length of the string, followed
// the string, followed by the string. For example, a length-1024 string // by the string. For example, a length-1024 string would be encoded as
// would be encoded as 0xB90400 followed by the string. The range of // 0xB90400 followed by the string. The range of the first byte is thus
// the first byte is thus [0xB8, 0xBF]. // [0xB8, 0xBF].
size, err = s.readUint(b - 0xB7) size, err = s.readUint(b - 0xB7)
if err == nil && size < 56 { if err == nil && size < 56 {
err = ErrCanonSize err = ErrCanonSize
} }
return String, size, err return String, size, err
case b < 0xF8: case b < 0xF8:
// If the total payload of a list // If the total payload of a list (i.e. the combined length of all its
// (i.e. the combined length of all its items) is 0-55 bytes long, the // items) is 0-55 bytes long, the RLP encoding consists of a single byte
// RLP encoding consists of a single byte with value 0xC0 plus the length // with value 0xC0 plus the length of the list followed by the
// of the list followed by the concatenation of the RLP encodings of the // concatenation of the RLP encodings of the items. The range of the
// items. The range of the first byte is thus [0xC0, 0xF7]. // first byte is thus [0xC0, 0xF7].
return List, uint64(b - 0xC0), nil return List, uint64(b - 0xC0), nil
default: default:
// If the total payload of a list is more than 55 bytes long, // If the total payload of a list is more than 55 bytes long, the RLP
// the RLP encoding consists of a single byte with value 0xF7 // encoding consists of a single byte with value 0xF7 plus the length of
// plus the length of the length of the payload in binary // the length of the payload in binary form, followed by the length of
// form, followed by the length of the payload, followed by // the payload, followed by the concatenation of the RLP encodings of
// the concatenation of the RLP encodings of the items. The // the items. The range of the first byte is thus [0xF8, 0xFF].
// range of the first byte is thus [0xF8, 0xFF].
size, err = s.readUint(b - 0xF7) size, err = s.readUint(b - 0xF7)
if err == nil && size < 56 { if err == nil && size < 56 {
err = ErrCanonSize err = ErrCanonSize
@ -940,22 +930,20 @@ func (s *Stream) readUint(size byte) (uint64, error) {
return uint64(b), err return uint64(b), err
default: default:
start := int(8 - size) start := int(8 - size)
for i := 0; i < start; i++ { s.uintbuf = [8]byte{}
s.uintbuf[i] = 0
}
if err := s.readFull(s.uintbuf[start:]); err != nil { if err := s.readFull(s.uintbuf[start:]); err != nil {
return 0, err return 0, err
} }
if s.uintbuf[start] == 0 { if s.uintbuf[start] == 0 {
// Note: readUint is also used to decode integer // Note: readUint is also used to decode integer values.
// values. The error needs to be adjusted to become // The error needs to be adjusted to become ErrCanonInt in this case.
// ErrCanonInt in this case.
return 0, ErrCanonSize return 0, ErrCanonSize
} }
return binary.BigEndian.Uint64(s.uintbuf), nil return binary.BigEndian.Uint64(s.uintbuf[:]), nil
} }
} }
// readFull reads into buf from the underlying stream.
func (s *Stream) readFull(buf []byte) (err error) { func (s *Stream) readFull(buf []byte) (err error) {
if err := s.willRead(uint64(len(buf))); err != nil { if err := s.willRead(uint64(len(buf))); err != nil {
return err return err
@ -977,6 +965,7 @@ func (s *Stream) readFull(buf []byte) (err error) {
return err return err
} }
// readByte reads a single byte from the underlying stream.
func (s *Stream) readByte() (byte, error) { func (s *Stream) readByte() (byte, error) {
if err := s.willRead(1); err != nil { if err := s.willRead(1); err != nil {
return 0, err return 0, err
@ -988,16 +977,16 @@ func (s *Stream) readByte() (byte, error) {
return b, err return b, err
} }
// willRead is called before any read from the underlying stream. It checks
// n against size limits, and updates the limits if n doesn't overflow them.
func (s *Stream) willRead(n uint64) error { func (s *Stream) willRead(n uint64) error {
s.kind = -1 // rearm Kind s.kind = -1 // rearm Kind
if len(s.stack) > 0 { if inList, limit := s.listLimit(); inList {
// check list overflow if n > limit {
tos := s.stack[len(s.stack)-1]
if n > tos.size-tos.pos {
return ErrElemTooLarge return ErrElemTooLarge
} }
s.stack[len(s.stack)-1].pos += n s.stack[len(s.stack)-1] = limit - n
} }
if s.limited { if s.limited {
if n > s.remaining { if n > s.remaining {
@ -1007,3 +996,11 @@ func (s *Stream) willRead(n uint64) error {
} }
return nil return nil
} }
// listLimit returns the amount of data remaining in the innermost list.
func (s *Stream) listLimit() (inList bool, limit uint64) {
if len(s.stack) == 0 {
return false, 0
}
return true, s.stack[len(s.stack)-1]
}