diff --git a/rlp/decode.go b/rlp/decode.go index c0b5f0699..1381f5274 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -183,6 +183,8 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { return decodeBigIntNoPtr, nil case isUint(kind): return decodeUint, nil + case kind == reflect.Bool: + return decodeBool, nil case kind == reflect.String: return decodeString, nil case kind == reflect.Slice || kind == reflect.Array: @@ -211,6 +213,15 @@ func decodeUint(s *Stream, val reflect.Value) error { return nil } +func decodeBool(s *Stream, val reflect.Value) error { + b, err := s.Bool() + if err != nil { + return wrapStreamError(err, val.Type()) + } + val.SetBool(b) + return nil +} + func decodeString(s *Stream, val reflect.Value) error { b, err := s.Bytes() if err != nil { @@ -697,6 +708,24 @@ func (s *Stream) uint(maxbits int) (uint64, error) { } } +// Bool reads an RLP string of up to 1 byte and returns its contents +// as an boolean. If the input does not contain an RLP string, the +// returned error will be ErrExpectedString. +func (s *Stream) Bool() (bool, error) { + num, err := s.uint(8) + if err != nil { + return false, err + } + switch num { + case 0: + return false, nil + case 1: + return true, nil + default: + return false, fmt.Errorf("rlp: invalid boolean value: %d", num) + } +} + // List starts decoding an RLP list. If the input does not contain a // list, the returned error will be ErrExpectedList. When the list's // end has been reached, any Stream operation will return EOL. diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 331faa9d8..d6b0611dd 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -19,6 +19,7 @@ package rlp import ( "bytes" "encoding/hex" + "errors" "fmt" "io" "math/big" @@ -116,6 +117,9 @@ func TestStreamErrors(t *testing.T) { {"817F", calls{"Uint"}, nil, ErrCanonSize}, {"8180", calls{"Uint"}, nil, nil}, + // Non-valid boolean + {"02", calls{"Bool"}, nil, errors.New("rlp: invalid boolean value: 2")}, + // Size tags must use the smallest possible encoding. // Leading zero bytes in the size tag are also rejected. {"8100", calls{"Uint"}, nil, ErrCanonSize}, @@ -315,6 +319,11 @@ var ( ) var decodeTests = []decodeTest{ + // booleans + {input: "01", ptr: new(bool), value: true}, + {input: "80", ptr: new(bool), value: false}, + {input: "02", ptr: new(bool), error: "rlp: invalid boolean value: 2"}, + // integers {input: "05", ptr: new(uint32), value: uint32(5)}, {input: "80", ptr: new(uint32), value: uint32(0)}, diff --git a/rlp/encode.go b/rlp/encode.go index 0ddef7bbb..b525ae4e7 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -361,6 +361,8 @@ func makeWriter(typ reflect.Type) (writer, error) { return writeBigIntNoPtr, nil case isUint(kind): return writeUint, nil + case kind == reflect.Bool: + return writeBool, nil case kind == reflect.String: return writeString, nil case kind == reflect.Slice && isByte(typ.Elem()): @@ -398,6 +400,15 @@ func writeUint(val reflect.Value, w *encbuf) error { return nil } +func writeBool(val reflect.Value, w *encbuf) error { + if val.Bool() { + w.str = append(w.str, 0x01) + } else { + w.str = append(w.str, 0x80) + } + return nil +} + func writeBigIntPtr(val reflect.Value, w *encbuf) error { ptr := val.Interface().(*big.Int) if ptr == nil { diff --git a/rlp/encode_test.go b/rlp/encode_test.go index e83c76b51..60bd95692 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -71,6 +71,10 @@ type encTest struct { } var encTests = []encTest{ + // booleans + {val: true, output: "01"}, + {val: false, output: "80"}, + // integers {val: uint32(0), output: "80"}, {val: uint32(127), output: "7F"},