diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go index 51b8506bc..51a10f3f3 100644 --- a/core/types/derive_sha.go +++ b/core/types/derive_sha.go @@ -17,8 +17,6 @@ package types import ( - "bytes" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" ) @@ -37,26 +35,24 @@ type Hasher interface { func DeriveSha(list DerivableList, hasher Hasher) common.Hash { hasher.Reset() - keybuf := new(bytes.Buffer) // StackTrie requires values to be inserted in increasing // hash order, which is not the order that `list` provides // hashes in. This insertion sequence ensures that the // order is correct. + + var buf []byte for i := 1; i < list.Len() && i <= 0x7f; i++ { - keybuf.Reset() - rlp.Encode(keybuf, uint(i)) - hasher.Update(keybuf.Bytes(), list.GetRlp(i)) + buf = rlp.AppendUint64(buf[:0], uint64(i)) + hasher.Update(buf, list.GetRlp(i)) } if list.Len() > 0 { - keybuf.Reset() - rlp.Encode(keybuf, uint(0)) - hasher.Update(keybuf.Bytes(), list.GetRlp(0)) + buf = rlp.AppendUint64(buf[:0], 0) + hasher.Update(buf, list.GetRlp(0)) } for i := 0x80; i < list.Len(); i++ { - keybuf.Reset() - rlp.Encode(keybuf, uint(i)) - hasher.Update(keybuf.Bytes(), list.GetRlp(i)) + buf = rlp.AppendUint64(buf[:0], uint64(i)) + hasher.Update(buf, list.GetRlp(i)) } return hasher.Hash() } diff --git a/rlp/raw.go b/rlp/raw.go index c2a8517f6..3071e99ca 100644 --- a/rlp/raw.go +++ b/rlp/raw.go @@ -180,3 +180,74 @@ func readSize(b []byte, slen byte) (uint64, error) { } return s, nil } + +// AppendUint64 appends the RLP encoding of i to b, and returns the resulting slice. +func AppendUint64(b []byte, i uint64) []byte { + if i == 0 { + return append(b, 0x80) + } else if i < 128 { + return append(b, byte(i)) + } + switch { + case i < (1 << 8): + return append(b, 0x81, byte(i)) + case i < (1 << 16): + return append(b, 0x82, + byte(i>>8), + byte(i), + ) + case i < (1 << 24): + return append(b, 0x83, + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 32): + return append(b, 0x84, + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 40): + return append(b, 0x85, + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + case i < (1 << 48): + return append(b, 0x86, + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 56): + return append(b, 0x87, + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + default: + return append(b, 0x88, + byte(i>>56), + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + } +} diff --git a/rlp/raw_test.go b/rlp/raw_test.go index cdae4ff08..c976c4f73 100644 --- a/rlp/raw_test.go +++ b/rlp/raw_test.go @@ -21,6 +21,7 @@ import ( "io" "reflect" "testing" + "testing/quick" ) func TestCountValues(t *testing.T) { @@ -239,3 +240,40 @@ func TestReadSize(t *testing.T) { } } } + +func TestAppendUint64(t *testing.T) { + tests := []struct { + input uint64 + slice []byte + output string + }{ + {0, nil, "80"}, + {1, nil, "01"}, + {2, nil, "02"}, + {127, nil, "7F"}, + {128, nil, "8180"}, + {129, nil, "8181"}, + {0xFFFFFF, nil, "83FFFFFF"}, + {127, []byte{1, 2, 3}, "0102037F"}, + {0xFFFFFF, []byte{1, 2, 3}, "01020383FFFFFF"}, + } + + for _, test := range tests { + x := AppendUint64(test.slice, test.input) + if !bytes.Equal(x, unhex(test.output)) { + t.Errorf("AppendUint64(%v, %d): got %x, want %s", test.slice, test.input, x, test.output) + } + } +} + +func TestAppendUint64Random(t *testing.T) { + fn := func(i uint64) bool { + enc, _ := EncodeToBytes(i) + encAppend := AppendUint64(nil, i) + return bytes.Equal(enc, encAppend) + } + config := quick.Config{MaxCountScale: 50} + if err := quick.Check(fn, &config); err != nil { + t.Fatal(err) + } +}