diff --git a/statediff/helpers.go b/statediff/helpers.go new file mode 100644 index 000000000..976d21fe2 --- /dev/null +++ b/statediff/helpers.go @@ -0,0 +1,114 @@ +package statediff + +import ( + "sort" + "strings" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/trie" +) + +func sortKeys(data map[common.Address]*state.Account) []string { + var keys []string + for key := range data { + keys = append(keys, key.Hex()) + } + sort.Strings(keys) + + return keys +} + +func findIntersection(a, b []string) []string { + lenA := len(a) + lenB := len(b) + iOfA, iOfB := 0, 0 + updates := make([]string, 0) + if iOfA >= lenA || iOfB >= lenB { + return updates + } + for { + switch strings.Compare(a[iOfA], b[iOfB]) { + // a[iOfA] < b[iOfB] + case -1: + iOfA++ + if iOfA >= lenA { + return updates + } + break + // a[iOfA] == b[iOfB] + case 0: + updates = append(updates, a[iOfA]) + iOfA++ + iOfB++ + if iOfA >= lenA || iOfB >= lenB { + return updates + } + break + // a[iOfA] > b[iOfB] + case 1: + iOfB++ + if iOfB >= lenB { + return updates + } + break + } + } + +} + +func pathToStr(it trie.NodeIterator) string { + path := it.Path() + if hasTerm(path) { + path = path[:len(path)-1] + } + nibblePath := "" + for i, v := range common.ToHex(path) { + if i%2 == 0 && i > 1 { + continue + } + nibblePath = nibblePath + string(v) + } + + return nibblePath +} + +// Duplicated from trie/encoding.go +func hexToKeybytes(hex []byte) []byte { + if hasTerm(hex) { + hex = hex[:len(hex)-1] + } + if len(hex)&1 != 0 { + panic("can't convert hex key of odd length") + } + key := make([]byte, (len(hex)+1)/2) + decodeNibbles(hex, key) + + return key +} + +func decodeNibbles(nibbles []byte, bytes []byte) { + for bi, ni := 0, 0; ni < len(nibbles); bi, ni = bi+1, ni+2 { + bytes[bi] = nibbles[ni]<<4 | nibbles[ni+1] + } +} + +// prefixLen returns the length of the common prefix of a and b. +func prefixLen(a, b []byte) int { + var i, length = 0, len(a) + if len(b) < length { + length = len(b) + } + for ; i < length; i++ { + if a[i] != b[i] { + break + } + } + + return i +} + +// hasTerm returns whether a hex key has the terminator flag. +func hasTerm(s []byte) bool { + return len(s) > 0 && s[len(s)-1] == 16 +} diff --git a/statediff/statediff.go b/statediff/statediff.go new file mode 100644 index 000000000..d980ef867 --- /dev/null +++ b/statediff/statediff.go @@ -0,0 +1,67 @@ +package statediff + +import ( + "encoding/json" + "math/big" + + "github.com/ethereum/go-ethereum/common" +) + +type StateDiff struct { + BlockNumber int64 `json:"blockNumber" gencodec:"required"` + BlockHash common.Hash `json:"blockHash" gencodec:"required"` + CreatedAccounts map[common.Address]AccountDiffEventual `json:"createdAccounts" gencodec:"required"` + DeletedAccounts map[common.Address]AccountDiffEventual `json:"deletedAccounts" gencodec:"required"` + UpdatedAccounts map[common.Address]AccountDiffIncremental `json:"updatedAccounts" gencodec:"required"` + + encoded []byte + err error +} + +func (self *StateDiff) ensureEncoded() { + if self.encoded == nil && self.err == nil { + self.encoded, self.err = json.Marshal(self) + } +} + +// Implement Encoder interface for StateDiff +func (sd *StateDiff) Length() int { + sd.ensureEncoded() + return len(sd.encoded) +} + +// Implement Encoder interface for StateDiff +func (sd *StateDiff) Encode() ([]byte, error) { + sd.ensureEncoded() + return sd.encoded, sd.err +} + +type AccountDiffEventual struct { + Nonce diffUint64 `json:"nonce" gencodec:"required"` + Balance diffBigInt `json:"balance" gencodec:"required"` + Code string `json:"code" gencodec:"required"` + CodeHash string `json:"codeHash" gencodec:"required"` + ContractRoot diffString `json:"contractRoot" gencodec:"required"` + Storage map[string]diffString `json:"storage" gencodec:"required"` +} + +type AccountDiffIncremental struct { + Nonce diffUint64 `json:"nonce" gencodec:"required"` + Balance diffBigInt `json:"balance" gencodec:"required"` + CodeHash string `json:"codeHash" gencodec:"required"` + ContractRoot diffString `json:"contractRoot" gencodec:"required"` + Storage map[string]diffString `json:"storage" gencodec:"required"` +} + +type diffString struct { + NewValue *string `json:"newValue" gencodec:"optional"` + OldValue *string `json:"oldValue" gencodec:"optional"` +} +type diffUint64 struct { + NewValue *uint64 `json:"newValue" gencodec:"optional"` + OldValue *uint64 `json:"oldValue" gencodec:"optional"` +} +type diffBigInt struct { + NewValue *big.Int `json:"newValue" gencodec:"optional"` + OldValue *big.Int `json:"oldValue" gencodec:"optional"` +} diff --git a/statediff/statediff_builder.go b/statediff/statediff_builder.go new file mode 100644 index 000000000..c3011ed66 --- /dev/null +++ b/statediff/statediff_builder.go @@ -0,0 +1,301 @@ +package statediff + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +type StateDiffBuilder interface { + CreateStateDiff(oldStateRoot, newStateRoot common.Hash, blockNumber int64, blockHash common.Hash) (*StateDiff, error) +} + +type stateDiffBuilder struct { + chainDB ethdb.Database + trieDB *trie.Database + cachedTrie *trie.Trie +} + +func NewStateDiffBuilder(db ethdb.Database) *stateDiffBuilder { + return &stateDiffBuilder{ + chainDB: db, + trieDB: trie.NewDatabase(db), + } +} + +func (sdb *stateDiffBuilder) CreateStateDiff(oldStateRoot, newStateRoot common.Hash, blockNumber int64, blockHash common.Hash) (*StateDiff, error) { + // Generate tries for old and new states + oldTrie, err := trie.New(oldStateRoot, sdb.trieDB) + if err != nil { + return nil, err + } + newTrie, err := trie.New(newStateRoot, sdb.trieDB) + if err != nil { + return nil, err + } + + // Find created accounts + oldIt := oldTrie.NodeIterator([]byte{}) + newIt := newTrie.NodeIterator([]byte{}) + creations, err := sdb.collectDiffNodes(oldIt, newIt) + if err != nil { + return nil, err + } + + // Find deleted accounts + oldIt = oldTrie.NodeIterator(make([]byte, 0)) + newIt = newTrie.NodeIterator(make([]byte, 0)) + deletions, err := sdb.collectDiffNodes(newIt, oldIt) + if err != nil { + return nil, err + } + + // Find all the diffed keys + createKeys := sortKeys(creations) + deleteKeys := sortKeys(deletions) + updatedKeys := findIntersection(createKeys, deleteKeys) + + // Build and return the statediff + updatedAccounts, err := sdb.buildDiffIncremental(creations, deletions, &updatedKeys) + if err != nil { + return nil, err + } + createdAccounts, err := sdb.buildDiffEventual(creations, true) + if err != nil { + return nil, err + } + deletedAccounts, err := sdb.buildDiffEventual(deletions, false) + if err != nil { + return nil, err + } + + return &StateDiff{ + BlockNumber: blockNumber, + BlockHash: blockHash, + CreatedAccounts: createdAccounts, + DeletedAccounts: deletedAccounts, + UpdatedAccounts: updatedAccounts, + }, nil +} + +func (sdb *stateDiffBuilder) collectDiffNodes(a, b trie.NodeIterator) (map[common.Address]*state.Account, error) { + var diffAccounts map[common.Address]*state.Account + it, _ := trie.NewDifferenceIterator(a, b) + + for { + log.Debug("Current Path and Hash", "path", pathToStr(it), "hashold", common.Hash(it.Hash())) + if it.Leaf() { + + // lookup address + path := make([]byte, len(it.Path())-1) + copy(path, it.Path()) + addr, err := sdb.addressByPath(path) + if err != nil { + log.Error("Error looking up address via path", "path", path, "error", err) + return nil, err + } + + // lookup account state + var account state.Account + if err := rlp.DecodeBytes(it.LeafBlob(), &account); err != nil { + log.Error("Error looking up account via address", "address", addr, "error", err) + return nil, err + } + + // record account to diffs (creation if we are looking at new - old; deletion if old - new) + log.Debug("Account lookup successful", "address", addr, "account", account) + diffAccounts[*addr] = &account + } + cont := it.Next(true) + if !cont { + break + } + } + return diffAccounts, nil +} + +func (sdb *stateDiffBuilder) buildDiffEventual(accounts map[common.Address]*state.Account, created bool) (map[common.Address]AccountDiffEventual, error) { + accountDiffs := make(map[common.Address]AccountDiffEventual) + for addr, val := range accounts { + sr := val.Root + if storageDiffs, err := sdb.buildStorageDiffsEventual(sr, created); err != nil { + log.Error("Failed building eventual storage diffs", "Address", val, "error", err) + return nil, err + } else { + code := "" + codeBytes, err := sdb.chainDB.Get(val.CodeHash) + if err == nil && len(codeBytes) != 0 { + code = common.ToHex(codeBytes) + } else { + log.Debug("No code field.", "codehash", val.CodeHash, "Address", val, "error", err) + } + codeHash := common.ToHex(val.CodeHash) + if created { + nonce := diffUint64{ + NewValue: &val.Nonce, + } + + balance := diffBigInt{ + NewValue: val.Balance, + } + + hexRoot := val.Root.Hex() + contractRoot := diffString{ + NewValue: &hexRoot, + } + accountDiffs[addr] = AccountDiffEventual{ + Nonce: nonce, + Balance: balance, + CodeHash: codeHash, + Code: code, + ContractRoot: contractRoot, + Storage: storageDiffs, + } + } else { + nonce := diffUint64{ + OldValue: &val.Nonce, + } + balance := diffBigInt{ + OldValue: val.Balance, + } + hexRoot := val.Root.Hex() + contractRoot := diffString{ + OldValue: &hexRoot, + } + accountDiffs[addr] = AccountDiffEventual{ + Nonce: nonce, + Balance: balance, + CodeHash: codeHash, + ContractRoot: contractRoot, + Storage: storageDiffs, + } + } + } + } + return accountDiffs, nil +} + +func (sdb *stateDiffBuilder) buildDiffIncremental(creations map[common.Address]*state.Account, deletions map[common.Address]*state.Account, updatedKeys *[]string) (map[common.Address]AccountDiffIncremental, error) { + updatedAccounts := make(map[common.Address]AccountDiffIncremental) + for _, val := range *updatedKeys { + createdAcc := creations[common.HexToAddress(val)] + deletedAcc := deletions[common.HexToAddress(val)] + oldSR := deletedAcc.Root + newSR := createdAcc.Root + if storageDiffs, err := sdb.buildStorageDiffsIncremental(oldSR, newSR); err != nil { + log.Error("Failed building storage diffs", "Address", val, "error", err) + return nil, err + } else { + nonce := diffUint64{ + NewValue: &createdAcc.Nonce, + OldValue: &deletedAcc.Nonce, + } + + balance := diffBigInt{ + NewValue: createdAcc.Balance, + OldValue: deletedAcc.Balance, + } + codeHash := common.ToHex(createdAcc.CodeHash) + + nHexRoot := createdAcc.Root.Hex() + oHexRoot := deletedAcc.Root.Hex() + contractRoot := diffString{ + NewValue: &nHexRoot, + OldValue: &oHexRoot, + } + + updatedAccounts[common.HexToAddress(val)] = AccountDiffIncremental{ + Nonce: nonce, + Balance: balance, + CodeHash: codeHash, + ContractRoot: contractRoot, + Storage: storageDiffs, + } + delete(creations, common.HexToAddress(val)) + delete(deletions, common.HexToAddress(val)) + } + } + return updatedAccounts, nil +} + +func (sdb *stateDiffBuilder) buildStorageDiffsEventual(sr common.Hash, creation bool) (map[string]diffString, error) { + log.Debug("Storage Root For Eventual Diff", "root", sr.Hex()) + sTrie, err := trie.New(sr, sdb.trieDB) + if err != nil { + return nil, err + } + it := sTrie.NodeIterator(make([]byte, 0)) + storageDiffs := make(map[string]diffString) + for { + log.Debug("Iterating over state at path ", "path", pathToStr(it)) + if it.Leaf() { + log.Debug("Found leaf in storage", "path", pathToStr(it)) + path := pathToStr(it) + value := common.ToHex(it.LeafBlob()) + if creation { + storageDiffs[path] = diffString{NewValue: &value} + } else { + storageDiffs[path] = diffString{OldValue: &value} + } + } + cont := it.Next(true) + if !cont { + break + } + } + return storageDiffs, nil +} + +func (sdb *stateDiffBuilder) buildStorageDiffsIncremental(oldSR common.Hash, newSR common.Hash) (map[string]diffString, error) { + log.Debug("Storage Roots for Incremental Diff", "old", oldSR.Hex(), "new", newSR.Hex()) + oldTrie, err := trie.New(oldSR, sdb.trieDB) + if err != nil { + return nil, err + } + newTrie, err := trie.New(newSR, sdb.trieDB) + if err != nil { + return nil, err + } + + oldIt := oldTrie.NodeIterator(make([]byte, 0)) + newIt := newTrie.NodeIterator(make([]byte, 0)) + it, _ := trie.NewDifferenceIterator(oldIt, newIt) + storageDiffs := make(map[string]diffString) + for { + if it.Leaf() { + log.Debug("Found leaf in storage", "path", pathToStr(it)) + path := pathToStr(it) + value := common.ToHex(it.LeafBlob()) + if oldVal, err := oldTrie.TryGet(it.LeafKey()); err != nil { + log.Error("Failed to look up value in oldTrie", "path", path, "error", err) + } else { + hexOldVal := common.ToHex(oldVal) + storageDiffs[path] = diffString{OldValue: &hexOldVal, NewValue: &value} + } + } + + cont := it.Next(true) + if !cont { + break + } + } + return storageDiffs, nil +} + +func (sdb *stateDiffBuilder) addressByPath(path []byte) (*common.Address, error) { + // db := core.PreimageTable(sdb.chainDb) + log.Debug("Looking up address from path", "path", common.ToHex(append([]byte("secure-key-"), path...))) + // if addrBytes,err := db.Get(path); err != nil { + if addrBytes, err := sdb.chainDB.Get(append([]byte("secure-key-"), hexToKeybytes(path)...)); err != nil { + log.Error("Error looking up address via path", "path", common.ToHex(append([]byte("secure-key-"), path...)), "error", err) + return nil, err + } else { + addr := common.BytesToAddress(addrBytes) + log.Debug("Address found", "Address", addr) + return &addr, nil + } + +}