diff --git a/pkg/api.go b/pkg/api.go index f22db38..334268e 100644 --- a/pkg/api.go +++ b/pkg/api.go @@ -17,6 +17,7 @@ package statediff import ( "context" + sd "github.com/ethereum/go-ethereum/statediff" ) diff --git a/pkg/builder.go b/pkg/builder.go index 5b8968e..742c1f7 100644 --- a/pkg/builder.go +++ b/pkg/builder.go @@ -22,8 +22,8 @@ package statediff import ( "bytes" "fmt" - "sync" "math/bits" + "sync" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" @@ -33,14 +33,15 @@ import ( "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" - iter "github.com/vulcanize/go-eth-state-node-iterator" sd "github.com/ethereum/go-ethereum/statediff" + iter "github.com/vulcanize/go-eth-state-node-iterator" ) var ( nullHashBytes = common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000") emptyNode, _ = rlp.EncodeToBytes([]byte{}) emptyContractRoot = crypto.Keccak256Hash(emptyNode) + nullCodeHash = crypto.Keccak256Hash([]byte{}).Bytes() ) // Builder interface exposes the method for building a state diff between two blocks @@ -79,14 +80,15 @@ func (sdb *builder) BuildStateTrieObject(current *types.Block) (sd.StateObject, return sd.StateObject{}, fmt.Errorf("error creating trie for block %d: %v", current.Number(), err) } it := currentTrie.NodeIterator([]byte{}) - stateNodes, err := sdb.buildStateTrie(it) + stateNodes, codeAndCodeHashes, err := sdb.buildStateTrie(it) if err != nil { return sd.StateObject{}, fmt.Errorf("error collecting state nodes for block %d: %v", current.Number(), err) } return sd.StateObject{ - BlockNumber: current.Number(), - BlockHash: current.Hash(), - Nodes: stateNodes, + BlockNumber: current.Number(), + BlockHash: current.Hash(), + Nodes: stateNodes, + CodeAndCodeHashes: codeAndCodeHashes, }, nil } @@ -112,49 +114,76 @@ func resolveNode(it trie.NodeIterator, trieDB *trie.Database) (sd.StateNode, []i }, nodeElements, nil } -func (sdb *builder) buildStateTrie(it trie.NodeIterator) ([]sd.StateNode, error) { +func (sdb *builder) buildStateTrie(it trie.NodeIterator) ([]sd.StateNode, []sd.CodeAndCodeHash, error) { stateNodes := make([]sd.StateNode, 0) + codeAndCodeHashes := make([]sd.CodeAndCodeHash, 0) for it.Next(true) { // skip value nodes - if it.Leaf() || bytes.Equal(nullHashBytes, it.Hash().Bytes()) { + if it.Leaf() { continue } - node, nodeElements, err := resolveNode(it, sdb.stateCache.TrieDB()) - if err != nil { - return nil, err + if bytes.Equal(nullHashBytes, it.Hash().Bytes()) { + continue } - switch node.NodeType { + nodePath := make([]byte, len(it.Path())) + copy(nodePath, it.Path()) + node, err := sdb.stateCache.TrieDB().Node(it.Hash()) + if err != nil { + return nil, nil, err + } + var nodeElements []interface{} + if err := rlp.DecodeBytes(node, &nodeElements); err != nil { + return nil, nil, err + } + ty, err := sd.CheckKeyType(nodeElements) + if err != nil { + return nil, nil, err + } + switch ty { case sd.Leaf: var account state.Account if err := rlp.DecodeBytes(nodeElements[1].([]byte), &account); err != nil { - return nil, fmt.Errorf("error decoding account for leaf node at path %x nerror: %v", node.Path, err) + return nil, nil, fmt.Errorf("error decoding account for leaf node at path %x nerror: %v", nodePath, err) } partialPath := trie.CompactToHex(nodeElements[0].([]byte)) - valueNodePath := append(node.Path, partialPath...) + valueNodePath := append(nodePath, partialPath...) encodedPath := trie.HexToCompact(valueNodePath) leafKey := encodedPath[1:] - storageNodes, err := sdb.buildStorageNodesEventual(account.Root, nil, true) - if err != nil { - return nil, fmt.Errorf("failed building eventual storage diffs for account %+v\r\nerror: %v", account, err) + node := sd.StateNode{ + NodeType: ty, + Path: nodePath, + LeafKey: leafKey, + NodeValue: node, } - stateNodes = append(stateNodes, sd.StateNode{ - NodeType: node.NodeType, - Path: node.Path, - LeafKey: leafKey, - NodeValue: node.NodeValue, - StorageNodes: storageNodes, - }) + if !bytes.Equal(account.CodeHash, nullCodeHash) { + storageNodes, err := sdb.buildStorageNodesEventual(account.Root, nil, true) + if err != nil { + return nil, nil, fmt.Errorf("failed building eventual storage diffs for account %+v\r\nerror: %v", account, err) + } + node.StorageNodes = storageNodes + // emit codehash => code mappings for cod + codeHash := common.BytesToHash(account.CodeHash) + code, err := sdb.stateCache.ContractCode(common.Hash{}, codeHash) + if err != nil { + return nil, nil, fmt.Errorf("failed to retrieve code for codehash %s\r\n error: %v", codeHash.String(), err) + } + codeAndCodeHashes = append(codeAndCodeHashes, sd.CodeAndCodeHash{ + Hash: codeHash, + Code: code, + }) + } + stateNodes = append(stateNodes, node) case sd.Extension, sd.Branch: stateNodes = append(stateNodes, sd.StateNode{ - NodeType: node.NodeType, - Path: node.Path, - NodeValue: node.NodeValue, + NodeType: ty, + Path: nodePath, + NodeValue: node, }) default: - return nil, fmt.Errorf("unexpected node type %s", node.NodeType) + return nil, nil, fmt.Errorf("unexpected node type %s", ty) } } - return stateNodes, it.Error() + return stateNodes, codeAndCodeHashes, it.Error() } // BuildStateDiff builds a statediff object from two blocks and the provided parameters @@ -195,7 +224,12 @@ func (sdb *builder) BuildStateDiffObject(args sd.Args, params sd.Params) (sd.Sta } } - nodeChan := make(chan []sd.StateNode) + type packet struct { + nodes []sd.StateNode + codes []sd.CodeAndCodeHash + } + + packetChan := make(chan packet) var wg sync.WaitGroup for w := uint(0); w < sdb.numWorkers; w++ { @@ -203,49 +237,55 @@ func (sdb *builder) BuildStateDiffObject(args sd.Args, params sd.Params) (sd.Sta go func(iterChan <-chan []iterPair) error { defer wg.Done() if iters, more := <-iterChan; more { - subtrieNodes, err := sdb.buildStateDiff(iters, params) + subtrieNodes, subtrieCodes, err := sdb.buildStateDiff(iters, params) if err != nil { return err } - nodeChan <- subtrieNodes + packetChan <- packet{ + nodes: subtrieNodes, + codes: subtrieCodes, + } } return nil }(iterChan) } go func() { - defer close(nodeChan) + defer close(packetChan) defer close(iterChan) wg.Wait() }() stateNodes := make([]sd.StateNode, 0) - for subtrieNodes := range nodeChan { - stateNodes = append(stateNodes, subtrieNodes...) + codeAndCodeHashes := make([]sd.CodeAndCodeHash, 0) + for packet := range packetChan { + stateNodes = append(stateNodes, packet.nodes...) + codeAndCodeHashes = append(codeAndCodeHashes, packet.codes...) } return sd.StateObject{ - BlockHash: args.BlockHash, - BlockNumber: args.BlockNumber, - Nodes: stateNodes, + BlockHash: args.BlockHash, + BlockNumber: args.BlockNumber, + Nodes: stateNodes, + CodeAndCodeHashes: codeAndCodeHashes, }, nil } -func (sdb *builder) buildStateDiff(args []iterPair, params sd.Params) ([]sd.StateNode, error) { +func (sdb *builder) buildStateDiff(args []iterPair, params sd.Params) ([]sd.StateNode, []sd.CodeAndCodeHash, error) { // collect a slice of all the intermediate nodes that were touched and exist at B // a map of their leafkey to all the accounts that were touched and exist at B // and a slice of all the paths for the nodes in both of the above sets createdOrUpdatedIntermediateNodes, diffAccountsAtB, diffPathsAtB, err := sdb.createdAndUpdatedState( args[0], params.WatchedAddresses, params.IntermediateStateNodes) if err != nil { - return nil, fmt.Errorf("error collecting createdAndUpdatedNodes: %v", err) + return nil, nil, fmt.Errorf("error collecting createdAndUpdatedNodes: %v", err) } // collect a slice of all the nodes that existed at a path in A that doesn't exist in B // a map of their leafkey to all the accounts that were touched and exist at A emptiedPaths, diffAccountsAtA, err := sdb.deletedOrUpdatedState(args[1], diffPathsAtB) if err != nil { - return nil, fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) + return nil, nil, fmt.Errorf("error collecting deletedOrUpdatedNodes: %v", err) } // collect and sort the leafkey keys for both account mappings into a slice @@ -264,28 +304,24 @@ func (sdb *builder) buildStateDiff(args []iterPair, params sd.Params) ([]sd.Stat diffAccountsAtB, diffAccountsAtA, updatedKeys, params.WatchedStorageSlots, params.IntermediateStorageNodes) if err != nil { - return nil, fmt.Errorf("error building diff for updated accounts: %v", err) + return nil, nil, fmt.Errorf("error building diff for updated accounts: %v", err) } // build the diff nodes for created accounts - createdAccounts, err := sdb.buildAccountCreations( + createdAccounts, codeAndCodeHashes, err := sdb.buildAccountCreations( diffAccountsAtB, params.WatchedStorageSlots, params.IntermediateStorageNodes) if err != nil { - return nil, fmt.Errorf("error building diff for created accounts: %v", err) + return nil, nil, fmt.Errorf("error building diff for created accounts: %v", err) } // assemble all of the nodes into the statediff object, including the intermediate nodes - res := append( + nodes := append( append( append(updatedAccounts, createdAccounts...), createdOrUpdatedIntermediateNodes..., ), emptiedPaths...) - var paths [][]byte - for _, n := range res { - paths = append(paths, n.Path) - } - return res, nil + return nodes, codeAndCodeHashes, nil } // createdAndUpdatedState returns @@ -430,24 +466,39 @@ func (sdb *builder) buildAccountUpdates(creations, deletions AccountMap, updated } // buildAccountCreations returns the statediff node objects for all the accounts that exist at B but not at A -func (sdb *builder) buildAccountCreations(accounts AccountMap, watchedStorageKeys []common.Hash, intermediateStorageNodes bool) ([]sd.StateNode, error) { +// it also returns the code and codehash for created contract accounts +func (sdb *builder) buildAccountCreations(accounts AccountMap, watchedStorageKeys []common.Hash, intermediateStorageNodes bool) ([]sd.StateNode, []sd.CodeAndCodeHash, error) { accountDiffs := make([]sd.StateNode, 0, len(accounts)) + codeAndCodeHashes := make([]sd.CodeAndCodeHash, 0) for _, val := range accounts { - // For account creations, any storage node contained is a diff - storageDiffs, err := sdb.buildStorageNodesEventual(val.Account.Root, watchedStorageKeys, intermediateStorageNodes) - if err != nil { - return nil, fmt.Errorf("failed building eventual storage diffs for node %x\r\nerror: %v", val.Path, err) + diff := sd.StateNode{ + NodeType: val.NodeType, + Path: val.Path, + LeafKey: val.LeafKey, + NodeValue: val.NodeValue, } - accountDiffs = append(accountDiffs, sd.StateNode{ - NodeType: val.NodeType, - Path: val.Path, - LeafKey: val.LeafKey, - NodeValue: val.NodeValue, - StorageNodes: storageDiffs, - }) + if !bytes.Equal(val.Account.CodeHash, nullCodeHash) { + // For contract creations, any storage node contained is a diff + storageDiffs, err := sdb.buildStorageNodesEventual(val.Account.Root, watchedStorageKeys, intermediateStorageNodes) + if err != nil { + return nil, nil, fmt.Errorf("failed building eventual storage diffs for node %x\r\nerror: %v", val.Path, err) + } + diff.StorageNodes = storageDiffs + // emit codehash => code mappings for new contracts + codeHash := common.BytesToHash(val.Account.CodeHash) + code, err := sdb.stateCache.ContractCode(common.Hash{}, codeHash) + if err != nil { + return nil, nil, fmt.Errorf("failed to retrieve code for codehash %s\r\n error: %v", codeHash.String(), err) + } + codeAndCodeHashes = append(codeAndCodeHashes, sd.CodeAndCodeHash{ + Hash: codeHash, + Code: code, + }) + } + accountDiffs = append(accountDiffs, diff) } - return accountDiffs, nil + return accountDiffs, codeAndCodeHashes, nil } // buildStorageNodesEventual builds the storage diff node objects for a created account