From 55599ee95d4151a2502465e0afc7c47bd1acba77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= Date: Mon, 5 Feb 2018 18:40:32 +0200 Subject: [PATCH] core, trie: intermediate mempool between trie and database (#15857) This commit reduces database I/O by not writing every state trie to disk. --- accounts/abi/bind/backends/simulated.go | 14 +- cmd/evm/runner.go | 4 +- cmd/geth/chaincmd.go | 4 +- cmd/geth/main.go | 3 + cmd/geth/usage.go | 6 +- cmd/utils/cmd.go | 26 +- cmd/utils/flags.go | 47 +++- common/size.go | 27 +- consensus/errors.go | 4 + core/bench_test.go | 4 +- core/block_validator.go | 9 +- core/block_validator_test.go | 8 +- core/blockchain.go | 336 +++++++++++++++++----- core/blockchain_test.go | 145 ++++++++-- core/chain_indexer.go | 3 + core/chain_makers.go | 9 +- core/chain_makers_test.go | 2 +- core/dao_test.go | 24 +- core/genesis.go | 24 +- core/genesis_test.go | 6 +- core/state/database.go | 51 +++- core/state/iterator_test.go | 17 +- core/state/state_object.go | 5 +- core/state/state_test.go | 6 +- core/state/statedb.go | 38 ++- core/state/statedb_test.go | 14 +- core/state/sync_test.go | 44 +-- core/tx_pool_test.go | 4 +- core/types/block.go | 9 + core/types/receipt.go | 13 + core/types/transaction.go | 2 + eth/api.go | 4 +- eth/api_tracer.go | 135 ++------- eth/backend.go | 8 +- eth/config.go | 8 +- eth/downloader/downloader.go | 317 +++++++++++---------- eth/downloader/downloader_test.go | 194 ++++--------- eth/downloader/queue.go | 171 ++++++------ eth/downloader/statesync.go | 31 +-- eth/handler.go | 6 +- eth/handler_test.go | 16 +- eth/helper_test.go | 14 +- eth/protocol_test.go | 6 +- eth/sync_test.go | 4 +- internal/ethapi/api.go | 2 +- les/handler.go | 197 +++++++------ les/handler_test.go | 2 +- les/helper_test.go | 2 +- les/odr_test.go | 1 - light/lightchain.go | 7 + light/nodeset.go | 8 +- light/odr_test.go | 4 +- light/postprocess.go | 64 +++-- light/trie.go | 18 +- light/trie_test.go | 2 +- light/txpool_test.go | 2 
+- miner/worker.go | 2 +- tests/block_test_util.go | 2 +- tests/state_test_util.go | 6 +- trie/database.go | 355 ++++++++++++++++++++++++ trie/hasher.go | 61 +++- trie/iterator_test.go | 125 ++++++--- trie/proof.go | 47 ++-- trie/secure_trie.go | 62 ++--- trie/secure_trie_test.go | 20 +- trie/sync.go | 14 +- trie/sync_test.go | 103 ++++--- trie/trie.go | 90 +++--- trie/trie_test.go | 104 +++---- 69 files changed, 1958 insertions(+), 1164 deletions(-) create mode 100644 trie/database.go diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 1803d3f23..bd342a8cb 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -68,7 +68,7 @@ func NewSimulatedBackend(alloc core.GenesisAlloc) *SimulatedBackend { database, _ := ethdb.NewMemDatabase() genesis := core.Genesis{Config: params.AllEthashProtocolChanges, Alloc: alloc} genesis.MustCommit(database) - blockchain, _ := core.NewBlockChain(database, genesis.Config, ethash.NewFaker(), vm.Config{}) + blockchain, _ := core.NewBlockChain(database, nil, genesis.Config, ethash.NewFaker(), vm.Config{}) backend := &SimulatedBackend{ database: database, @@ -102,8 +102,10 @@ func (b *SimulatedBackend) Rollback() { func (b *SimulatedBackend) rollback() { blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), ethash.NewFaker(), b.database, 1, func(int, *core.BlockGen) {}) + statedb, _ := b.blockchain.State() + b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database)) + b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database()) } // CodeAt returns the code associated with a certain account in the blockchain. 
@@ -309,8 +311,10 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa } block.AddTx(tx) }) + statedb, _ := b.blockchain.State() + b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database)) + b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database()) return nil } @@ -386,8 +390,10 @@ func (b *SimulatedBackend) AdjustTime(adjustment time.Duration) error { } block.OffsetTime(int64(adjustment.Seconds())) }) + statedb, _ := b.blockchain.State() + b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database)) + b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database()) return nil } diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index 96de0c76a..a9a2e5420 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -96,7 +96,9 @@ func runCmd(ctx *cli.Context) error { } if ctx.GlobalString(GenesisFlag.Name) != "" { gen := readGenesis(ctx.GlobalString(GenesisFlag.Name)) - _, statedb = gen.ToBlock() + db, _ := ethdb.NewMemDatabase() + genesis := gen.ToBlock(db) + statedb, _ = state.New(genesis.Root(), state.NewDatabase(db)) chainConfig = gen.Config } else { db, _ := ethdb.NewMemDatabase() diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go index 4a9a7b11b..35bf576e1 100644 --- a/cmd/geth/chaincmd.go +++ b/cmd/geth/chaincmd.go @@ -202,7 +202,7 @@ func importChain(ctx *cli.Context) error { if len(ctx.Args()) == 1 { if err := utils.ImportChain(chain, ctx.Args().First()); err != nil { - utils.Fatalf("Import error: %v", err) + log.Error("Import error", "err", err) } } else { for _, arg := range ctx.Args() { @@ -211,7 +211,7 @@ func importChain(ctx *cli.Context) error { } } } - + chain.Stop() fmt.Printf("Import done in %v.\n\n", time.Since(start)) // Output pre-compaction stats mostly to see the import trashing diff --git a/cmd/geth/main.go b/cmd/geth/main.go index b955bd243..cb8d63bf7 100644 --- 
a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -85,10 +85,13 @@ var ( utils.FastSyncFlag, utils.LightModeFlag, utils.SyncModeFlag, + utils.GCModeFlag, utils.LightServFlag, utils.LightPeersFlag, utils.LightKDFFlag, utils.CacheFlag, + utils.CacheDatabaseFlag, + utils.CacheGCFlag, utils.TrieCacheGenFlag, utils.ListenPortFlag, utils.MaxPeersFlag, diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go index a834d5b7a..a2bcaff02 100644 --- a/cmd/geth/usage.go +++ b/cmd/geth/usage.go @@ -22,10 +22,11 @@ import ( "io" "sort" + "strings" + "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/internal/debug" "gopkg.in/urfave/cli.v1" - "strings" ) // AppHelpTemplate is the test template for the default, global app help topic. @@ -74,6 +75,7 @@ var AppHelpFlagGroups = []flagGroup{ utils.TestnetFlag, utils.RinkebyFlag, utils.SyncModeFlag, + utils.GCModeFlag, utils.EthStatsURLFlag, utils.IdentityFlag, utils.LightServFlag, @@ -127,6 +129,8 @@ var AppHelpFlagGroups = []flagGroup{ Name: "PERFORMANCE TUNING", Flags: []cli.Flag{ utils.CacheFlag, + utils.CacheDatabaseFlag, + utils.CacheGCFlag, utils.TrieCacheGenFlag, }, }, diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index 23b10c2d7..53cdf7861 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -116,7 +116,6 @@ func ImportChain(chain *core.BlockChain, fn string) error { return err } } - stream := rlp.NewStream(reader, 0) // Run actual the import. 
@@ -150,25 +149,34 @@ func ImportChain(chain *core.BlockChain, fn string) error { if checkInterrupt() { return fmt.Errorf("interrupted") } - if hasAllBlocks(chain, blocks[:i]) { + missing := missingBlocks(chain, blocks[:i]) + if len(missing) == 0 { log.Info("Skipping batch as all blocks present", "batch", batch, "first", blocks[0].Hash(), "last", blocks[i-1].Hash()) continue } - - if _, err := chain.InsertChain(blocks[:i]); err != nil { + if _, err := chain.InsertChain(missing); err != nil { return fmt.Errorf("invalid block %d: %v", n, err) } } return nil } -func hasAllBlocks(chain *core.BlockChain, bs []*types.Block) bool { - for _, b := range bs { - if !chain.HasBlock(b.Hash(), b.NumberU64()) { - return false +func missingBlocks(chain *core.BlockChain, blocks []*types.Block) []*types.Block { + head := chain.CurrentBlock() + for i, block := range blocks { + // If we're behind the chain head, only check block, state is available at head + if head.NumberU64() > block.NumberU64() { + if !chain.HasBlock(block.Hash(), block.NumberU64()) { + return blocks[i:] + } + continue + } + // If we're above the chain head, state availability is a must + if !chain.HasBlockAndState(block.Hash(), block.NumberU64()) { + return blocks[i:] } } - return true + return nil } func ExportChain(blockchain *core.BlockChain, fn string) error { diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 833cd95de..2a2909ff2 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -170,7 +170,11 @@ var ( Usage: `Blockchain sync mode ("fast", "full", or "light")`, Value: &defaultSyncMode, } - + GCModeFlag = cli.StringFlag{ + Name: "gcmode", + Usage: `Blockchain garbage collection mode ("full", "archive")`, + Value: "full", + } LightServFlag = cli.IntFlag{ Name: "lightserv", Usage: "Maximum percentage of time allowed for serving LES requests (0-90)", @@ -293,8 +297,18 @@ var ( // Performance tuning settings CacheFlag = cli.IntFlag{ Name: "cache", - Usage: "Megabytes of memory allocated to 
internal caching (min 16MB / database forced)", - Value: 128, + Usage: "Megabytes of memory allocated to internal caching", + Value: 1024, + } + CacheDatabaseFlag = cli.IntFlag{ + Name: "cache.database", + Usage: "Percentage of cache memory allowance to use for database io", + Value: 75, + } + CacheGCFlag = cli.IntFlag{ + Name: "cache.gc", + Usage: "Percentage of cache memory allowance to use for trie pruning", + Value: 25, } TrieCacheGenFlag = cli.IntFlag{ Name: "trie-cache-gens", @@ -1021,11 +1035,19 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { cfg.NetworkId = ctx.GlobalUint64(NetworkIdFlag.Name) } - if ctx.GlobalIsSet(CacheFlag.Name) { - cfg.DatabaseCache = ctx.GlobalInt(CacheFlag.Name) + if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheDatabaseFlag.Name) { + cfg.DatabaseCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheDatabaseFlag.Name) / 100 } cfg.DatabaseHandles = makeDatabaseHandles() + if gcmode := ctx.GlobalString(GCModeFlag.Name); gcmode != "full" && gcmode != "archive" { + Fatalf("--%s must be either 'full' or 'archive'", GCModeFlag.Name) + } + cfg.NoPruning = ctx.GlobalString(GCModeFlag.Name) == "archive" + + if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheGCFlag.Name) { + cfg.TrieCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheGCFlag.Name) / 100 + } if ctx.GlobalIsSet(MinerThreadsFlag.Name) { cfg.MinerThreads = ctx.GlobalInt(MinerThreadsFlag.Name) } @@ -1157,7 +1179,7 @@ func SetupNetwork(ctx *cli.Context) { // MakeChainDatabase open an LevelDB using the flags passed to the client and will hard crash if it fails. 
func MakeChainDatabase(ctx *cli.Context, stack *node.Node) ethdb.Database { var ( - cache = ctx.GlobalInt(CacheFlag.Name) + cache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheDatabaseFlag.Name) / 100 handles = makeDatabaseHandles() ) name := "chaindata" @@ -1209,8 +1231,19 @@ func MakeChain(ctx *cli.Context, stack *node.Node) (chain *core.BlockChain, chai }) } } + if gcmode := ctx.GlobalString(GCModeFlag.Name); gcmode != "full" && gcmode != "archive" { + Fatalf("--%s must be either 'full' or 'archive'", GCModeFlag.Name) + } + cache := &core.CacheConfig{ + Disabled: ctx.GlobalString(GCModeFlag.Name) == "archive", + TrieNodeLimit: eth.DefaultConfig.TrieCache, + TrieTimeLimit: eth.DefaultConfig.TrieTimeout, + } + if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheGCFlag.Name) { + cache.TrieNodeLimit = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheGCFlag.Name) / 100 + } vmcfg := vm.Config{EnablePreimageRecording: ctx.GlobalBool(VMEnableDebugFlag.Name)} - chain, err = core.NewBlockChain(chainDb, config, engine, vmcfg) + chain, err = core.NewBlockChain(chainDb, cache, config, engine, vmcfg) if err != nil { Fatalf("Can't create BlockChain: %v", err) } diff --git a/common/size.go b/common/size.go index c5a0cb0f2..bd0fc85c7 100644 --- a/common/size.go +++ b/common/size.go @@ -20,18 +20,29 @@ import ( "fmt" ) +// StorageSize is a wrapper around a float value that supports user friendly +// formatting. type StorageSize float64 -func (self StorageSize) String() string { - if self > 1000000 { - return fmt.Sprintf("%.2f mB", self/1000000) - } else if self > 1000 { - return fmt.Sprintf("%.2f kB", self/1000) +// String implements the stringer interface. 
+func (s StorageSize) String() string { + if s > 1000000 { + return fmt.Sprintf("%.2f mB", s/1000000) + } else if s > 1000 { + return fmt.Sprintf("%.2f kB", s/1000) } else { - return fmt.Sprintf("%.2f B", self) + return fmt.Sprintf("%.2f B", s) } } -func (self StorageSize) Int64() int64 { - return int64(self) +// TerminalString implements log.TerminalStringer, formatting a string for console +// output during logging. +func (s StorageSize) TerminalString() string { + if s > 1000000 { + return fmt.Sprintf("%.2fmB", s/1000000) + } else if s > 1000 { + return fmt.Sprintf("%.2fkB", s/1000) + } else { + return fmt.Sprintf("%.2fB", s) + } } diff --git a/consensus/errors.go b/consensus/errors.go index 3b136dbdd..a005c5f63 100644 --- a/consensus/errors.go +++ b/consensus/errors.go @@ -23,6 +23,10 @@ var ( // that is unknown. ErrUnknownAncestor = errors.New("unknown ancestor") + // ErrPrunedAncestor is returned when validating a block requires an ancestor + // that is known, but the state of which is not available. + ErrPrunedAncestor = errors.New("pruned ancestor") + // ErrFutureBlock is returned when a block's timestamp is in the future according // to the current node. ErrFutureBlock = errors.New("block in the future") diff --git a/core/bench_test.go b/core/bench_test.go index f976331d1..e23f0d19d 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -173,7 +173,7 @@ func benchInsertChain(b *testing.B, disk bool, gen func(int, *BlockGen)) { // Time the insertion of the new chain. // State and blocks are stored in the same DB. 
- chainman, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) + chainman, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) defer chainman.Stop() b.ReportAllocs() b.ResetTimer() @@ -283,7 +283,7 @@ func benchReadChain(b *testing.B, full bool, count uint64) { if err != nil { b.Fatalf("error opening database at %v: %v", dir, err) } - chain, err := NewBlockChain(db, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) + chain, err := NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) if err != nil { b.Fatalf("error creating chain: %v", err) } diff --git a/core/block_validator.go b/core/block_validator.go index 143728bb8..98958809b 100644 --- a/core/block_validator.go +++ b/core/block_validator.go @@ -50,11 +50,14 @@ func NewBlockValidator(config *params.ChainConfig, blockchain *BlockChain, engin // validated at this point. func (v *BlockValidator) ValidateBody(block *types.Block) error { // Check whether the block's known, and if not, that it's linkable - if v.bc.HasBlockAndState(block.Hash()) { + if v.bc.HasBlockAndState(block.Hash(), block.NumberU64()) { return ErrKnownBlock } - if !v.bc.HasBlockAndState(block.ParentHash()) { - return consensus.ErrUnknownAncestor + if !v.bc.HasBlockAndState(block.ParentHash(), block.NumberU64()-1) { + if !v.bc.HasBlock(block.ParentHash(), block.NumberU64()-1) { + return consensus.ErrUnknownAncestor + } + return consensus.ErrPrunedAncestor } // Header validity is known at this point, check the uncles and transactions header := block.Header() diff --git a/core/block_validator_test.go b/core/block_validator_test.go index e668601f3..e334b3c3c 100644 --- a/core/block_validator_test.go +++ b/core/block_validator_test.go @@ -42,7 +42,7 @@ func TestHeaderVerification(t *testing.T) { headers[i] = block.Header() } // Run the header checker for blocks one-by-one, checking for both valid and invalid nonces - chain, _ := NewBlockChain(testdb, params.TestChainConfig, 
ethash.NewFaker(), vm.Config{}) + chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) defer chain.Stop() for i := 0; i < len(blocks); i++ { @@ -106,11 +106,11 @@ func testHeaderConcurrentVerification(t *testing.T, threads int) { var results <-chan error if valid { - chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) + chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}) _, results = chain.engine.VerifyHeaders(chain, headers, seals) chain.Stop() } else { - chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFakeFailer(uint64(len(headers)-1)), vm.Config{}) + chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFakeFailer(uint64(len(headers)-1)), vm.Config{}) _, results = chain.engine.VerifyHeaders(chain, headers, seals) chain.Stop() } @@ -173,7 +173,7 @@ func testHeaderConcurrentAbortion(t *testing.T, threads int) { defer runtime.GOMAXPROCS(old) // Start the verifications and immediately abort - chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFakeDelayer(time.Millisecond), vm.Config{}) + chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFakeDelayer(time.Millisecond), vm.Config{}) defer chain.Stop() abort, results := chain.engine.VerifyHeaders(chain, headers, seals) diff --git a/core/blockchain.go b/core/blockchain.go index d5e139e31..8d141fddb 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -42,6 +42,7 @@ import ( "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" "github.com/hashicorp/golang-lru" + "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) var ( @@ -56,11 +57,20 @@ const ( maxFutureBlocks = 256 maxTimeFutureBlocks = 30 badBlockLimit = 10 + triesInMemory = 128 // BlockChainVersion ensures that an incompatible database forces a resync from scratch. 
BlockChainVersion = 3 ) +// CacheConfig contains the configuration values for the trie caching/pruning +// that's resident in a blockchain. +type CacheConfig struct { + Disabled bool // Whether to disable trie write caching (archive node) + TrieNodeLimit int // Memory limit (MB) at which to flush the current in-memory trie to disk + TrieTimeLimit time.Duration // Time limit after which to flush the current in-memory trie to disk +} + // BlockChain represents the canonical chain given a database with a genesis // block. The Blockchain manages chain imports, reverts, chain reorganisations. // @@ -76,10 +86,14 @@ const ( // included in the canonical one where as GetBlockByNumber always represents the // canonical chain. type BlockChain struct { - config *params.ChainConfig // chain & network configuration + chainConfig *params.ChainConfig // Chain & network configuration + cacheConfig *CacheConfig // Cache configuration for pruning + + db ethdb.Database // Low level persistent database to store final content in + triegc *prque.Prque // Priority queue mapping block numbers to tries to gc + gcproc time.Duration // Accumulates canonical block processing for trie dumping hc *HeaderChain - chainDb ethdb.Database rmLogsFeed event.Feed chainFeed event.Feed chainSideFeed event.Feed @@ -119,7 +133,13 @@ type BlockChain struct { // NewBlockChain returns a fully initialised block chain using information // available in the database. It initialises the default Ethereum Validator and // Processor. 
-func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config) (*BlockChain, error) { +func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config) (*BlockChain, error) { + if cacheConfig == nil { + cacheConfig = &CacheConfig{ + TrieNodeLimit: 256 * 1024 * 1024, + TrieTimeLimit: 5 * time.Minute, + } + } bodyCache, _ := lru.New(bodyCacheLimit) bodyRLPCache, _ := lru.New(bodyCacheLimit) blockCache, _ := lru.New(blockCacheLimit) @@ -127,9 +147,11 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co badBlocks, _ := lru.New(badBlockLimit) bc := &BlockChain{ - config: config, - chainDb: chainDb, - stateCache: state.NewDatabase(chainDb), + chainConfig: chainConfig, + cacheConfig: cacheConfig, + db: db, + triegc: prque.New(), + stateCache: state.NewDatabase(db), quit: make(chan struct{}), bodyCache: bodyCache, bodyRLPCache: bodyRLPCache, @@ -139,11 +161,11 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co vmConfig: vmConfig, badBlocks: badBlocks, } - bc.SetValidator(NewBlockValidator(config, bc, engine)) - bc.SetProcessor(NewStateProcessor(config, bc, engine)) + bc.SetValidator(NewBlockValidator(chainConfig, bc, engine)) + bc.SetProcessor(NewStateProcessor(chainConfig, bc, engine)) var err error - bc.hc, err = NewHeaderChain(chainDb, config, engine, bc.getProcInterrupt) + bc.hc, err = NewHeaderChain(db, chainConfig, engine, bc.getProcInterrupt) if err != nil { return nil, err } @@ -180,7 +202,7 @@ func (bc *BlockChain) getProcInterrupt() bool { // assumes that the chain manager mutex is held. 
func (bc *BlockChain) loadLastState() error { // Restore the last known head block - head := GetHeadBlockHash(bc.chainDb) + head := GetHeadBlockHash(bc.db) if head == (common.Hash{}) { // Corrupt or empty database, init from scratch log.Warn("Empty database, resetting chain") @@ -196,15 +218,17 @@ func (bc *BlockChain) loadLastState() error { // Make sure the state associated with the block is available if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil { // Dangling block without a state associated, init from scratch - log.Warn("Head state missing, resetting chain", "number", currentBlock.Number(), "hash", currentBlock.Hash()) - return bc.Reset() + log.Warn("Head state missing, repairing chain", "number", currentBlock.Number(), "hash", currentBlock.Hash()) + if err := bc.repair(&currentBlock); err != nil { + return err + } } // Everything seems to be fine, set as the head block bc.currentBlock = currentBlock // Restore the last known head header currentHeader := bc.currentBlock.Header() - if head := GetHeadHeaderHash(bc.chainDb); head != (common.Hash{}) { + if head := GetHeadHeaderHash(bc.db); head != (common.Hash{}) { if header := bc.GetHeaderByHash(head); header != nil { currentHeader = header } @@ -213,7 +237,7 @@ func (bc *BlockChain) loadLastState() error { // Restore the last known head fast block bc.currentFastBlock = bc.currentBlock - if head := GetHeadFastBlockHash(bc.chainDb); head != (common.Hash{}) { + if head := GetHeadFastBlockHash(bc.db); head != (common.Hash{}) { if block := bc.GetBlockByHash(head); block != nil { bc.currentFastBlock = block } @@ -243,7 +267,7 @@ func (bc *BlockChain) SetHead(head uint64) error { // Rewind the header chain, deleting all block bodies until then delFn := func(hash common.Hash, num uint64) { - DeleteBody(bc.chainDb, hash, num) + DeleteBody(bc.db, hash, num) } bc.hc.SetHead(head, delFn) currentHeader := bc.hc.CurrentHeader() @@ -275,10 +299,10 @@ func (bc *BlockChain) SetHead(head uint64) error { if 
bc.currentFastBlock == nil { bc.currentFastBlock = bc.genesisBlock } - if err := WriteHeadBlockHash(bc.chainDb, bc.currentBlock.Hash()); err != nil { + if err := WriteHeadBlockHash(bc.db, bc.currentBlock.Hash()); err != nil { log.Crit("Failed to reset head full block", "err", err) } - if err := WriteHeadFastBlockHash(bc.chainDb, bc.currentFastBlock.Hash()); err != nil { + if err := WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash()); err != nil { log.Crit("Failed to reset head fast block", "err", err) } return bc.loadLastState() @@ -292,7 +316,7 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error { if block == nil { return fmt.Errorf("non existent block [%x…]", hash[:4]) } - if _, err := trie.NewSecure(block.Root(), bc.chainDb, 0); err != nil { + if _, err := trie.NewSecure(block.Root(), bc.stateCache.TrieDB(), 0); err != nil { return err } // If all checks out, manually set the head block @@ -387,7 +411,7 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error { if err := bc.hc.WriteTd(genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil { log.Crit("Failed to write genesis block TD", "err", err) } - if err := WriteBlock(bc.chainDb, genesis); err != nil { + if err := WriteBlock(bc.db, genesis); err != nil { log.Crit("Failed to write genesis block", "err", err) } bc.genesisBlock = genesis @@ -400,6 +424,24 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error { return nil } +// repair tries to repair the current blockchain by rolling back the current block +// until one with associated state is found. This is needed to fix incomplete db +// writes caused either by crashes/power outages, or simply non-committed tries. +// +// This method only rolls back the current block. The current header and current +// fast block are left intact. 
+func (bc *BlockChain) repair(head **types.Block) error { + for { + // Abort if we've rewound to a head block that does have associated state + if _, err := state.New((*head).Root(), bc.stateCache); err == nil { + log.Info("Rewound blockchain to past state", "number", (*head).Number(), "hash", (*head).Hash()) + return nil + } + // Otherwise rewind one block and recheck state availability there + (*head) = bc.GetBlock((*head).ParentHash(), (*head).NumberU64()-1) + } +} + // Export writes the active chain to the given writer. func (bc *BlockChain) Export(w io.Writer) error { return bc.ExportN(w, uint64(0), bc.currentBlock.NumberU64()) @@ -437,13 +479,13 @@ func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error { // Note, this function assumes that the `mu` mutex is held! func (bc *BlockChain) insert(block *types.Block) { // If the block is on a side chain or an unknown one, force other heads onto it too - updateHeads := GetCanonicalHash(bc.chainDb, block.NumberU64()) != block.Hash() + updateHeads := GetCanonicalHash(bc.db, block.NumberU64()) != block.Hash() // Add the block to the canonical chain number scheme and mark as the head - if err := WriteCanonicalHash(bc.chainDb, block.Hash(), block.NumberU64()); err != nil { + if err := WriteCanonicalHash(bc.db, block.Hash(), block.NumberU64()); err != nil { log.Crit("Failed to insert block number", "err", err) } - if err := WriteHeadBlockHash(bc.chainDb, block.Hash()); err != nil { + if err := WriteHeadBlockHash(bc.db, block.Hash()); err != nil { log.Crit("Failed to insert head block hash", "err", err) } bc.currentBlock = block @@ -452,7 +494,7 @@ func (bc *BlockChain) insert(block *types.Block) { if updateHeads { bc.hc.SetCurrentHeader(block.Header()) - if err := WriteHeadFastBlockHash(bc.chainDb, block.Hash()); err != nil { + if err := WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil { log.Crit("Failed to insert head fast block hash", "err", err) } bc.currentFastBlock = block @@ -472,7 +514,7 
@@ func (bc *BlockChain) GetBody(hash common.Hash) *types.Body { body := cached.(*types.Body) return body } - body := GetBody(bc.chainDb, hash, bc.hc.GetBlockNumber(hash)) + body := GetBody(bc.db, hash, bc.hc.GetBlockNumber(hash)) if body == nil { return nil } @@ -488,7 +530,7 @@ func (bc *BlockChain) GetBodyRLP(hash common.Hash) rlp.RawValue { if cached, ok := bc.bodyRLPCache.Get(hash); ok { return cached.(rlp.RawValue) } - body := GetBodyRLP(bc.chainDb, hash, bc.hc.GetBlockNumber(hash)) + body := GetBodyRLP(bc.db, hash, bc.hc.GetBlockNumber(hash)) if len(body) == 0 { return nil } @@ -502,21 +544,25 @@ func (bc *BlockChain) HasBlock(hash common.Hash, number uint64) bool { if bc.blockCache.Contains(hash) { return true } - ok, _ := bc.chainDb.Has(blockBodyKey(hash, number)) + ok, _ := bc.db.Has(blockBodyKey(hash, number)) return ok } +// HasState checks if state trie is fully present in the database or not. +func (bc *BlockChain) HasState(hash common.Hash) bool { + _, err := bc.stateCache.OpenTrie(hash) + return err == nil +} + // HasBlockAndState checks if a block and associated state trie is fully present // in the database or not, caching it if present. 
-func (bc *BlockChain) HasBlockAndState(hash common.Hash) bool { +func (bc *BlockChain) HasBlockAndState(hash common.Hash, number uint64) bool { // Check first that the block itself is known - block := bc.GetBlockByHash(hash) + block := bc.GetBlock(hash, number) if block == nil { return false } - // Ensure the associated state is also present - _, err := bc.stateCache.OpenTrie(block.Root()) - return err == nil + return bc.HasState(block.Root()) } // GetBlock retrieves a block from the database by hash and number, @@ -526,7 +572,7 @@ func (bc *BlockChain) GetBlock(hash common.Hash, number uint64) *types.Block { if block, ok := bc.blockCache.Get(hash); ok { return block.(*types.Block) } - block := GetBlock(bc.chainDb, hash, number) + block := GetBlock(bc.db, hash, number) if block == nil { return nil } @@ -543,13 +589,18 @@ func (bc *BlockChain) GetBlockByHash(hash common.Hash) *types.Block { // GetBlockByNumber retrieves a block from the database by number, caching it // (associated with its hash) if found. func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block { - hash := GetCanonicalHash(bc.chainDb, number) + hash := GetCanonicalHash(bc.db, number) if hash == (common.Hash{}) { return nil } return bc.GetBlock(hash, number) } +// GetReceiptsByHash retrieves the receipts for all transactions in a given block. +func (bc *BlockChain) GetReceiptsByHash(hash common.Hash) types.Receipts { + return GetBlockReceipts(bc.db, hash, GetBlockNumber(bc.db, hash)) +} + // GetBlocksFromHash returns the block corresponding to hash and up to n-1 ancestors. // [deprecated by eth/62] func (bc *BlockChain) GetBlocksFromHash(hash common.Hash, n int) (blocks []*types.Block) { @@ -577,6 +628,12 @@ func (bc *BlockChain) GetUnclesInChain(block *types.Block, length int) []*types. return uncles } +// TrieNode retrieves a blob of data associated with a trie node (or code hash) +// either from ephemeral in-memory cache, or from persistent storage. 
+func (bc *BlockChain) TrieNode(hash common.Hash) ([]byte, error) { + return bc.stateCache.TrieDB().Node(hash) +} + // Stop stops the blockchain service. If any imports are currently in progress // it will abort them using the procInterrupt. func (bc *BlockChain) Stop() { @@ -589,6 +646,33 @@ func (bc *BlockChain) Stop() { atomic.StoreInt32(&bc.procInterrupt, 1) bc.wg.Wait() + + // Ensure the state of a recent block is also stored to disk before exiting. + // It is fine if this state does not exist (fast start/stop cycle), but it is + // advisable to leave an N block gap from the head so 1) a restart loads up + // the last N blocks as sync assistance to remote nodes; 2) a restart during + // a (small) reorg doesn't require deep reprocesses; 3) chain "repair" from + // missing states are constantly tested. + // + // This may be tuned a bit on mainnet if its too annoying to reprocess the last + // N blocks. + if !bc.cacheConfig.Disabled { + triedb := bc.stateCache.TrieDB() + if number := bc.CurrentBlock().NumberU64(); number >= triesInMemory { + recent := bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - triesInMemory + 1) + + log.Info("Writing cached state to disk", "block", recent.Number(), "hash", recent.Hash(), "root", recent.Root()) + if err := triedb.Commit(recent.Root(), true); err != nil { + log.Error("Failed to commit recent state trie", "err", err) + } + } + for !bc.triegc.Empty() { + triedb.Dereference(bc.triegc.PopItem().(common.Hash), common.Hash{}) + } + if size := triedb.Size(); size != 0 { + log.Error("Dangling trie nodes after full cleanup") + } + } log.Info("Blockchain manager stopped") } @@ -633,11 +717,11 @@ func (bc *BlockChain) Rollback(chain []common.Hash) { } if bc.currentFastBlock.Hash() == hash { bc.currentFastBlock = bc.GetBlock(bc.currentFastBlock.ParentHash(), bc.currentFastBlock.NumberU64()-1) - WriteHeadFastBlockHash(bc.chainDb, bc.currentFastBlock.Hash()) + WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash()) } if 
bc.currentBlock.Hash() == hash { bc.currentBlock = bc.GetBlock(bc.currentBlock.ParentHash(), bc.currentBlock.NumberU64()-1) - WriteHeadBlockHash(bc.chainDb, bc.currentBlock.Hash()) + WriteHeadBlockHash(bc.db, bc.currentBlock.Hash()) } } } @@ -696,7 +780,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [ stats = struct{ processed, ignored int32 }{} start = time.Now() bytes = 0 - batch = bc.chainDb.NewBatch() + batch = bc.db.NewBatch() ) for i, block := range blockChain { receipts := receiptChain[i] @@ -714,7 +798,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [ continue } // Compute all the non-consensus fields of the receipts - SetReceiptsData(bc.config, block, receipts) + SetReceiptsData(bc.chainConfig, block, receipts) // Write all the data out into the database if err := WriteBody(batch, block.Hash(), block.NumberU64(), block.Body()); err != nil { return i, fmt.Errorf("failed to write block body: %v", err) @@ -747,7 +831,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [ head := blockChain[len(blockChain)-1] if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case if bc.GetTd(bc.currentFastBlock.Hash(), bc.currentFastBlock.NumberU64()).Cmp(td) < 0 { - if err := WriteHeadFastBlockHash(bc.chainDb, head.Hash()); err != nil { + if err := WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil { log.Crit("Failed to update head fast block hash", "err", err) } bc.currentFastBlock = head @@ -758,15 +842,33 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [ log.Info("Imported new block receipts", "count", stats.processed, "elapsed", common.PrettyDuration(time.Since(start)), - "bytes", bytes, "number", head.Number(), "hash", head.Hash(), + "size", common.StorageSize(bytes), "ignored", stats.ignored) return 0, nil } -// WriteBlock writes the block to the chain. 
-func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.Receipt, state *state.StateDB) (status WriteStatus, err error) { +var lastWrite uint64 + +// WriteBlockWithoutState writes only the block and its metadata to the database, +// but does not write any state. This is used to construct competing side forks +// up to the point where they exceed the canonical total difficulty. +func (bc *BlockChain) WriteBlockWithoutState(block *types.Block, td *big.Int) (err error) { + bc.wg.Add(1) + defer bc.wg.Done() + + if err := bc.hc.WriteTd(block.Hash(), block.NumberU64(), td); err != nil { + return err + } + if err := WriteBlock(bc.db, block); err != nil { + return err + } + return nil +} + +// WriteBlockWithState writes the block and all associated state to the database. +func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.Receipt, state *state.StateDB) (status WriteStatus, err error) { bc.wg.Add(1) defer bc.wg.Done() @@ -787,17 +889,73 @@ func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.R return NonStatTy, err } // Write other block data using a batch. 
- batch := bc.chainDb.NewBatch() + batch := bc.db.NewBatch() if err := WriteBlock(batch, block); err != nil { return NonStatTy, err } - if _, err := state.CommitTo(batch, bc.config.IsEIP158(block.Number())); err != nil { + root, err := state.Commit(bc.chainConfig.IsEIP158(block.Number())) + if err != nil { return NonStatTy, err } + triedb := bc.stateCache.TrieDB() + + // If we're running an archive node, always flush + if bc.cacheConfig.Disabled { + if err := triedb.Commit(root, false); err != nil { + return NonStatTy, err + } + } else { + // Full but not archive node, do proper garbage collection + triedb.Reference(root, common.Hash{}) // metadata reference to keep trie alive + bc.triegc.Push(root, -float32(block.NumberU64())) + + if current := block.NumberU64(); current > triesInMemory { + // Find the next state trie we need to commit + header := bc.GetHeaderByNumber(current - triesInMemory) + chosen := header.Number.Uint64() + + // Only write to disk if we exceeded our memory allowance *and* also have at + // least a given number of tries gapped. + var ( + size = triedb.Size() + limit = common.StorageSize(bc.cacheConfig.TrieNodeLimit) * 1024 * 1024 + ) + if size > limit || bc.gcproc > bc.cacheConfig.TrieTimeLimit { + // If we're exceeding limits but haven't reached a large enough memory gap, + // warn the user that the system is becoming unstable. 
+ if chosen < lastWrite+triesInMemory { + switch { + case size >= 2*limit: + log.Error("Trie memory critical, forcing to disk", "size", size, "limit", limit, "optimum", float64(chosen-lastWrite)/triesInMemory) + case bc.gcproc >= 2*bc.cacheConfig.TrieTimeLimit: + log.Error("Trie timing critical, forcing to disk", "time", bc.gcproc, "allowance", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/triesInMemory) + case size > limit: + log.Warn("Trie memory at dangerous levels", "size", size, "limit", limit, "optimum", float64(chosen-lastWrite)/triesInMemory) + case bc.gcproc > bc.cacheConfig.TrieTimeLimit: + log.Warn("Trie timing at dangerous levels", "time", bc.gcproc, "limit", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/triesInMemory) + } + } + // If optimum or critical limits reached, write to disk + if chosen >= lastWrite+triesInMemory || size >= 2*limit || bc.gcproc >= 2*bc.cacheConfig.TrieTimeLimit { + triedb.Commit(header.Root, true) + lastWrite = chosen + bc.gcproc = 0 + } + } + // Garbage collect anything below our required write retention + for !bc.triegc.Empty() { + root, number := bc.triegc.Pop() + if uint64(-number) > chosen { + bc.triegc.Push(root, number) + break + } + triedb.Dereference(root.(common.Hash), common.Hash{}) + } + } + } if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil { return NonStatTy, err } - // If the total difficulty is higher than our known, add it to the canonical chain // Second clause in the if statement reduces the vulnerability to selfish mining. 
// Please refer to http://www.cs.cornell.edu/~ie53/publications/btcProcFC.pdf @@ -818,7 +976,7 @@ func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.R return NonStatTy, err } // Write hash preimages - if err := WritePreimages(bc.chainDb, block.NumberU64(), state.Preimages()); err != nil { + if err := WritePreimages(bc.db, block.NumberU64(), state.Preimages()); err != nil { return NonStatTy, err } status = CanonStatTy @@ -910,31 +1068,60 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty if err == nil { err = bc.Validator().ValidateBody(block) } - if err != nil { - if err == ErrKnownBlock { - stats.ignored++ - continue - } + switch { + case err == ErrKnownBlock: + stats.ignored++ + continue - if err == consensus.ErrFutureBlock { - // Allow up to MaxFuture second in the future blocks. If this limit - // is exceeded the chain is discarded and processed at a later time - // if given. - max := big.NewInt(time.Now().Unix() + maxTimeFutureBlocks) - if block.Time().Cmp(max) > 0 { - return i, events, coalescedLogs, fmt.Errorf("future block: %v > %v", block.Time(), max) + case err == consensus.ErrFutureBlock: + // Allow up to MaxFuture second in the future blocks. If this limit is exceeded + // the chain is discarded and processed at a later time if given. 
+ max := big.NewInt(time.Now().Unix() + maxTimeFutureBlocks) + if block.Time().Cmp(max) > 0 { + return i, events, coalescedLogs, fmt.Errorf("future block: %v > %v", block.Time(), max) + } + bc.futureBlocks.Add(block.Hash(), block) + stats.queued++ + continue + + case err == consensus.ErrUnknownAncestor && bc.futureBlocks.Contains(block.ParentHash()): + bc.futureBlocks.Add(block.Hash(), block) + stats.queued++ + continue + + case err == consensus.ErrPrunedAncestor: + // Block competing with the canonical chain, store in the db, but don't process + // until the competitor TD goes above the canonical TD + localTd := bc.GetTd(bc.currentBlock.Hash(), bc.currentBlock.NumberU64()) + externTd := new(big.Int).Add(bc.GetTd(block.ParentHash(), block.NumberU64()-1), block.Difficulty()) + if localTd.Cmp(externTd) > 0 { + if err = bc.WriteBlockWithoutState(block, externTd); err != nil { + return i, events, coalescedLogs, err } - bc.futureBlocks.Add(block.Hash(), block) - stats.queued++ continue } + // Competitor chain beat canonical, gather all blocks from the common ancestor + var winner []*types.Block - if err == consensus.ErrUnknownAncestor && bc.futureBlocks.Contains(block.ParentHash()) { - bc.futureBlocks.Add(block.Hash(), block) - stats.queued++ - continue + parent := bc.GetBlock(block.ParentHash(), block.NumberU64()-1) + for !bc.HasState(parent.Root()) { + winner = append(winner, parent) + parent = bc.GetBlock(parent.ParentHash(), parent.NumberU64()-1) + } + for j := 0; j < len(winner)/2; j++ { + winner[j], winner[len(winner)-1-j] = winner[len(winner)-1-j], winner[j] + } + // Import all the pruned blocks to make the state available + bc.chainmu.Unlock() + _, evs, logs, err := bc.insertChain(winner) + bc.chainmu.Lock() + events, coalescedLogs = evs, logs + + if err != nil { + return i, events, coalescedLogs, err } + case err != nil: bc.reportBlock(block, nil, err) return i, events, coalescedLogs, err } @@ -962,8 +1149,10 @@ func (bc *BlockChain) insertChain(chain 
types.Blocks) (int, []interface{}, []*ty bc.reportBlock(block, receipts, err) return i, events, coalescedLogs, err } + proctime := time.Since(bstart) + // Write the block to the chain and get the status. - status, err := bc.WriteBlockAndState(block, receipts, state) + status, err := bc.WriteBlockWithState(block, receipts, state) if err != nil { return i, events, coalescedLogs, err } @@ -977,6 +1166,9 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty events = append(events, ChainEvent{block, block.Hash(), logs}) lastCanon = block + // Only count canonical blocks for GC processing time + bc.gcproc += proctime + case SideStatTy: log.Debug("Inserted forked block", "number", block.Number(), "hash", block.Hash(), "diff", block.Difficulty(), "elapsed", common.PrettyDuration(time.Since(bstart)), "txs", len(block.Transactions()), "gas", block.GasUsed(), "uncles", len(block.Uncles())) @@ -986,7 +1178,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty } stats.processed++ stats.usedGas += usedGas - stats.report(chain, i) + stats.report(chain, i, bc.stateCache.TrieDB().Size()) } // Append a single chain head event if we've progressed the chain if lastCanon != nil && bc.CurrentBlock().Hash() == lastCanon.Hash() { @@ -1009,7 +1201,7 @@ const statsReportLimit = 8 * time.Second // report prints statistics if some number of blocks have been processed // or more than a few seconds have passed since the last message. 
-func (st *insertStats) report(chain []*types.Block, index int) { +func (st *insertStats) report(chain []*types.Block, index int, cache common.StorageSize) { // Fetch the timings for the batch var ( now = mclock.Now() @@ -1024,7 +1216,7 @@ func (st *insertStats) report(chain []*types.Block, index int) { context := []interface{}{ "blocks", st.processed, "txs", txs, "mgas", float64(st.usedGas) / 1000000, "elapsed", common.PrettyDuration(elapsed), "mgasps", float64(st.usedGas) * 1000 / float64(elapsed), - "number", end.Number(), "hash", end.Hash(), + "number", end.Number(), "hash", end.Hash(), "cache", cache, } if st.queued > 0 { context = append(context, []interface{}{"queued", st.queued}...) @@ -1060,7 +1252,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { // These logs are later announced as deleted. collectLogs = func(h common.Hash) { // Coalesce logs and set 'Removed'. - receipts := GetBlockReceipts(bc.chainDb, h, bc.hc.GetBlockNumber(h)) + receipts := GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h)) for _, receipt := range receipts { for _, log := range receipt.Logs { del := *log @@ -1129,7 +1321,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { // insert the block in the canonical way, re-writing history bc.insert(newChain[i]) // write lookup entries for hash based transaction/receipt searches - if err := WriteTxLookupEntries(bc.chainDb, newChain[i]); err != nil { + if err := WriteTxLookupEntries(bc.db, newChain[i]); err != nil { return err } addedTxs = append(addedTxs, newChain[i].Transactions()...) 
@@ -1139,7 +1331,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { // When transactions get deleted from the database that means the // receipts that were created in the fork must also be deleted for _, tx := range diff { - DeleteTxLookupEntry(bc.chainDb, tx.Hash()) + DeleteTxLookupEntry(bc.db, tx.Hash()) } if len(deletedLogs) > 0 { go bc.rmLogsFeed.Send(RemovedLogsEvent{deletedLogs}) @@ -1231,7 +1423,7 @@ Hash: 0x%x Error: %v ############################## -`, bc.config, block.Number(), block.Hash(), receiptString, err)) +`, bc.chainConfig, block.Number(), block.Hash(), receiptString, err)) } // InsertHeaderChain attempts to insert the given header chain in to the local @@ -1338,7 +1530,7 @@ func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header { } // Config retrieves the blockchain's chain configuration. -func (bc *BlockChain) Config() *params.ChainConfig { return bc.config } +func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig } // Engine retrieves the blockchain's consensus engine. 
func (bc *BlockChain) Engine() consensus.Engine { return bc.engine } diff --git a/core/blockchain_test.go b/core/blockchain_test.go index cbde3bcd2..635379161 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -46,7 +46,7 @@ func newTestBlockChain(fake bool) *BlockChain { if !fake { engine = ethash.NewTester() } - blockchain, err := NewBlockChain(db, gspec.Config, engine, vm.Config{}) + blockchain, err := NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}) if err != nil { panic(err) } @@ -148,9 +148,9 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error { return err } blockchain.mu.Lock() - WriteTd(blockchain.chainDb, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash()))) - WriteBlock(blockchain.chainDb, block) - statedb.CommitTo(blockchain.chainDb, false) + WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash()))) + WriteBlock(blockchain.db, block) + statedb.Commit(false) blockchain.mu.Unlock() } return nil @@ -166,8 +166,8 @@ func testHeaderChainImport(chain []*types.Header, blockchain *BlockChain) error } // Manually insert the header into the database, but don't reorganise (allows subsequent testing) blockchain.mu.Lock() - WriteTd(blockchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash))) - WriteHeader(blockchain.chainDb, header) + WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash))) + WriteHeader(blockchain.db, header) blockchain.mu.Unlock() } return nil @@ -186,9 +186,9 @@ func TestLastBlock(t *testing.T) { bchain := newTestBlockChain(false) defer bchain.Stop() - block := makeBlockChain(bchain.CurrentBlock(), 1, ethash.NewFaker(), bchain.chainDb, 0)[0] + block := makeBlockChain(bchain.CurrentBlock(), 1, 
ethash.NewFaker(), bchain.db, 0)[0] bchain.insert(block) - if block.Hash() != GetHeadBlockHash(bchain.chainDb) { + if block.Hash() != GetHeadBlockHash(bchain.db) { t.Errorf("Write/Get HeadBlockHash failed") } } @@ -496,7 +496,7 @@ func testReorgBadHashes(t *testing.T, full bool) { } // Create a new BlockChain and check that it rolled back the state. - ncm, err := NewBlockChain(bc.chainDb, bc.config, ethash.NewFaker(), vm.Config{}) + ncm, err := NewBlockChain(bc.db, nil, bc.chainConfig, ethash.NewFaker(), vm.Config{}) if err != nil { t.Fatalf("failed to create new chain manager: %v", err) } @@ -609,7 +609,7 @@ func TestFastVsFullChains(t *testing.T) { // Import the chain as an archive node for the comparison baseline archiveDb, _ := ethdb.NewMemDatabase() gspec.MustCommit(archiveDb) - archive, _ := NewBlockChain(archiveDb, gspec.Config, ethash.NewFaker(), vm.Config{}) + archive, _ := NewBlockChain(archiveDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) defer archive.Stop() if n, err := archive.InsertChain(blocks); err != nil { @@ -618,7 +618,7 @@ func TestFastVsFullChains(t *testing.T) { // Fast import the chain as a non-archive node to test fastDb, _ := ethdb.NewMemDatabase() gspec.MustCommit(fastDb) - fast, _ := NewBlockChain(fastDb, gspec.Config, ethash.NewFaker(), vm.Config{}) + fast, _ := NewBlockChain(fastDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) defer fast.Stop() headers := make([]*types.Header, len(blocks)) @@ -696,7 +696,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) { archiveDb, _ := ethdb.NewMemDatabase() gspec.MustCommit(archiveDb) - archive, _ := NewBlockChain(archiveDb, gspec.Config, ethash.NewFaker(), vm.Config{}) + archive, _ := NewBlockChain(archiveDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) if n, err := archive.InsertChain(blocks); err != nil { t.Fatalf("failed to process block %d: %v", n, err) } @@ -709,7 +709,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) { // Import the chain as a non-archive node 
and ensure all pointers are updated fastDb, _ := ethdb.NewMemDatabase() gspec.MustCommit(fastDb) - fast, _ := NewBlockChain(fastDb, gspec.Config, ethash.NewFaker(), vm.Config{}) + fast, _ := NewBlockChain(fastDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) defer fast.Stop() headers := make([]*types.Header, len(blocks)) @@ -730,7 +730,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) { lightDb, _ := ethdb.NewMemDatabase() gspec.MustCommit(lightDb) - light, _ := NewBlockChain(lightDb, gspec.Config, ethash.NewFaker(), vm.Config{}) + light, _ := NewBlockChain(lightDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) if n, err := light.InsertHeaderChain(headers, 1); err != nil { t.Fatalf("failed to insert header %d: %v", n, err) } @@ -799,7 +799,7 @@ func TestChainTxReorgs(t *testing.T) { } }) // Import the chain. This runs all block validation rules. - blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) + blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) if i, err := blockchain.InsertChain(chain); err != nil { t.Fatalf("failed to insert original chain[%d]: %v", i, err) } @@ -870,7 +870,7 @@ func TestLogReorgs(t *testing.T) { signer = types.NewEIP155Signer(gspec.Config.ChainId) ) - blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) + blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) defer blockchain.Stop() rmLogsCh := make(chan RemovedLogsEvent) @@ -917,7 +917,7 @@ func TestReorgSideEvent(t *testing.T) { signer = types.NewEIP155Signer(gspec.Config.ChainId) ) - blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) + blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) defer blockchain.Stop() chain, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 3, func(i int, gen *BlockGen) {}) @@ -992,7 +992,7 @@ func TestCanonicalBlockRetrieval(t *testing.T) { bc := 
newTestBlockChain(true) defer bc.Stop() - chain, _ := GenerateChain(bc.config, bc.genesisBlock, ethash.NewFaker(), bc.chainDb, 10, func(i int, gen *BlockGen) {}) + chain, _ := GenerateChain(bc.chainConfig, bc.genesisBlock, ethash.NewFaker(), bc.db, 10, func(i int, gen *BlockGen) {}) var pend sync.WaitGroup pend.Add(len(chain)) @@ -1003,14 +1003,14 @@ func TestCanonicalBlockRetrieval(t *testing.T) { // try to retrieve a block by its canonical hash and see if the block data can be retrieved. for { - ch := GetCanonicalHash(bc.chainDb, block.NumberU64()) + ch := GetCanonicalHash(bc.db, block.NumberU64()) if ch == (common.Hash{}) { continue // busy wait for canonical hash to be written } if ch != block.Hash() { t.Fatalf("unknown canonical hash, want %s, got %s", block.Hash().Hex(), ch.Hex()) } - fb := GetBlock(bc.chainDb, ch, block.NumberU64()) + fb := GetBlock(bc.db, ch, block.NumberU64()) if fb == nil { t.Fatalf("unable to retrieve block %d for canonical hash: %s", block.NumberU64(), ch.Hex()) } @@ -1043,7 +1043,7 @@ func TestEIP155Transition(t *testing.T) { genesis = gspec.MustCommit(db) ) - blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) + blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) defer blockchain.Stop() blocks, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 4, func(i int, block *BlockGen) { @@ -1151,7 +1151,7 @@ func TestEIP161AccountRemoval(t *testing.T) { } genesis = gspec.MustCommit(db) ) - blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) + blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) defer blockchain.Stop() blocks, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 3, func(i int, block *BlockGen) { @@ -1226,7 +1226,7 @@ func TestBlockchainHeaderchainReorgConsistency(t *testing.T) { diskdb, _ := ethdb.NewMemDatabase() new(Genesis).MustCommit(diskdb) - chain, err := 
NewBlockChain(diskdb, params.TestChainConfig, engine, vm.Config{}) + chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{}) if err != nil { t.Fatalf("failed to create tester chain: %v", err) } @@ -1245,3 +1245,102 @@ func TestBlockchainHeaderchainReorgConsistency(t *testing.T) { } } } + +// Tests that importing small side forks doesn't leave junk in the trie database +// cache (which would eventually cause memory issues). +func TestTrieForkGC(t *testing.T) { + // Generate a canonical chain to act as the main dataset + engine := ethash.NewFaker() + + db, _ := ethdb.NewMemDatabase() + genesis := new(Genesis).MustCommit(db) + blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 2*triesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) }) + + // Generate a bunch of fork blocks, each side forking from the canonical chain + forks := make([]*types.Block, len(blocks)) + for i := 0; i < len(forks); i++ { + parent := genesis + if i > 0 { + parent = blocks[i-1] + } + fork, _ := GenerateChain(params.TestChainConfig, parent, engine, db, 1, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{2}) }) + forks[i] = fork[0] + } + // Import the canonical and fork chain side by side, forcing the trie cache to cache both + diskdb, _ := ethdb.NewMemDatabase() + new(Genesis).MustCommit(diskdb) + + chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{}) + if err != nil { + t.Fatalf("failed to create tester chain: %v", err) + } + for i := 0; i < len(blocks); i++ { + if _, err := chain.InsertChain(blocks[i : i+1]); err != nil { + t.Fatalf("block %d: failed to insert into chain: %v", i, err) + } + if _, err := chain.InsertChain(forks[i : i+1]); err != nil { + t.Fatalf("fork %d: failed to insert into chain: %v", i, err) + } + } + // Dereference all the recent tries and ensure no past trie is left in + for i := 0; i < triesInMemory; i++ { + 
chain.stateCache.TrieDB().Dereference(blocks[len(blocks)-1-i].Root(), common.Hash{}) + chain.stateCache.TrieDB().Dereference(forks[len(blocks)-1-i].Root(), common.Hash{}) + } + if len(chain.stateCache.TrieDB().Nodes()) > 0 { + t.Fatalf("stale tries still alive after garbase collection") + } +} + +// Tests that doing large reorgs works even if the state associated with the +// forking point is not available any more. +func TestLargeReorgTrieGC(t *testing.T) { + // Generate the original common chain segment and the two competing forks + engine := ethash.NewFaker() + + db, _ := ethdb.NewMemDatabase() + genesis := new(Genesis).MustCommit(db) + + shared, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 64, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) }) + original, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*triesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{2}) }) + competitor, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*triesInMemory+1, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{3}) }) + + // Import the shared chain and the original canonical one + diskdb, _ := ethdb.NewMemDatabase() + new(Genesis).MustCommit(diskdb) + + chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{}) + if err != nil { + t.Fatalf("failed to create tester chain: %v", err) + } + if _, err := chain.InsertChain(shared); err != nil { + t.Fatalf("failed to insert shared chain: %v", err) + } + if _, err := chain.InsertChain(original); err != nil { + t.Fatalf("failed to insert shared chain: %v", err) + } + // Ensure that the state associated with the forking point is pruned away + if node, _ := chain.stateCache.TrieDB().Node(shared[len(shared)-1].Root()); node != nil { + t.Fatalf("common-but-old ancestor still cache") + } + // Import the competitor chain without exceeding the canonical's TD and ensure + // we have not processed any of the 
blocks (protection against malicious blocks) + if _, err := chain.InsertChain(competitor[:len(competitor)-2]); err != nil { + t.Fatalf("failed to insert competitor chain: %v", err) + } + for i, block := range competitor[:len(competitor)-2] { + if node, _ := chain.stateCache.TrieDB().Node(block.Root()); node != nil { + t.Fatalf("competitor %d: low TD chain became processed", i) + } + } + // Import the head of the competitor chain, triggering the reorg and ensure we + // successfully reprocess all the stashed away blocks. + if _, err := chain.InsertChain(competitor[len(competitor)-2:]); err != nil { + t.Fatalf("failed to finalize competitor chain: %v", err) + } + for i, block := range competitor[:len(competitor)-triesInMemory] { + if node, _ := chain.stateCache.TrieDB().Node(block.Root()); node != nil { + t.Fatalf("competitor %d: competing chain state missing", i) + } + } +} diff --git a/core/chain_indexer.go b/core/chain_indexer.go index 7fb184aaa..158ed8324 100644 --- a/core/chain_indexer.go +++ b/core/chain_indexer.go @@ -203,6 +203,9 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, events chan ChainE if header.ParentHash != prevHash { // Reorg to the common ancestor (might not exist in light sync mode, skip reorg then) // TODO(karalabe, zsfelfoldi): This seems a bit brittle, can we detect this case explicitly? + + // TODO(karalabe): This operation is expensive and might block, causing the event system to + // potentially also lock up. We need to do with on a different thread somehow. 
if h := FindCommonAncestor(c.chainDb, prevHeader, header); h != nil { c.newHead(h.Number.Uint64(), true) } diff --git a/core/chain_makers.go b/core/chain_makers.go index 5e264a994..6744428ff 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -166,7 +166,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse genblock := func(i int, parent *types.Block, statedb *state.StateDB) (*types.Block, types.Receipts) { // TODO(karalabe): This is needed for clique, which depends on multiple blocks. // It's nonetheless ugly to spin up a blockchain here. Get rid of this somehow. - blockchain, _ := NewBlockChain(db, config, engine, vm.Config{}) + blockchain, _ := NewBlockChain(db, nil, config, engine, vm.Config{}) defer blockchain.Stop() b := &BlockGen{i: i, parent: parent, chain: blocks, chainReader: blockchain, statedb: statedb, config: config, engine: engine} @@ -192,10 +192,13 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse if b.engine != nil { block, _ := b.engine.Finalize(b.chainReader, b.header, statedb, b.txs, b.uncles, b.receipts) // Write state changes to db - _, err := statedb.CommitTo(db, config.IsEIP158(b.header.Number)) + root, err := statedb.Commit(config.IsEIP158(b.header.Number)) if err != nil { panic(fmt.Sprintf("state write error: %v", err)) } + if err := statedb.Database().TrieDB().Commit(root, false); err != nil { + panic(fmt.Sprintf("trie write error: %v", err)) + } return block, b.receipts } return nil, nil @@ -246,7 +249,7 @@ func newCanonical(engine consensus.Engine, n int, full bool) (ethdb.Database, *B db, _ := ethdb.NewMemDatabase() genesis := gspec.MustCommit(db) - blockchain, _ := NewBlockChain(db, params.AllEthashProtocolChanges, engine, vm.Config{}) + blockchain, _ := NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{}) // Create and inject the requested chain if n == 0 { return db, blockchain, nil diff --git a/core/chain_makers_test.go 
b/core/chain_makers_test.go index a3b80da29..93be43ddc 100644 --- a/core/chain_makers_test.go +++ b/core/chain_makers_test.go @@ -79,7 +79,7 @@ func ExampleGenerateChain() { }) // Import the chain. This runs all block validation rules. - blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{}) + blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) defer blockchain.Stop() if i, err := blockchain.InsertChain(chain); err != nil { diff --git a/core/dao_test.go b/core/dao_test.go index 43e2982a5..e0a3e3ff3 100644 --- a/core/dao_test.go +++ b/core/dao_test.go @@ -45,7 +45,7 @@ func TestDAOForkRangeExtradata(t *testing.T) { proConf.DAOForkBlock = forkBlock proConf.DAOForkSupport = true - proBc, _ := NewBlockChain(proDb, &proConf, ethash.NewFaker(), vm.Config{}) + proBc, _ := NewBlockChain(proDb, nil, &proConf, ethash.NewFaker(), vm.Config{}) defer proBc.Stop() conDb, _ := ethdb.NewMemDatabase() @@ -55,7 +55,7 @@ func TestDAOForkRangeExtradata(t *testing.T) { conConf.DAOForkBlock = forkBlock conConf.DAOForkSupport = false - conBc, _ := NewBlockChain(conDb, &conConf, ethash.NewFaker(), vm.Config{}) + conBc, _ := NewBlockChain(conDb, nil, &conConf, ethash.NewFaker(), vm.Config{}) defer conBc.Stop() if _, err := proBc.InsertChain(prefix); err != nil { @@ -69,7 +69,7 @@ func TestDAOForkRangeExtradata(t *testing.T) { // Create a pro-fork block, and try to feed into the no-fork chain db, _ = ethdb.NewMemDatabase() gspec.MustCommit(db) - bc, _ := NewBlockChain(db, &conConf, ethash.NewFaker(), vm.Config{}) + bc, _ := NewBlockChain(db, nil, &conConf, ethash.NewFaker(), vm.Config{}) defer bc.Stop() blocks := conBc.GetBlocksFromHash(conBc.CurrentBlock().Hash(), int(conBc.CurrentBlock().NumberU64())) @@ -79,6 +79,9 @@ func TestDAOForkRangeExtradata(t *testing.T) { if _, err := bc.InsertChain(blocks); err != nil { t.Fatalf("failed to import contra-fork chain for expansion: %v", err) } + if err := 
bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil { + t.Fatalf("failed to commit contra-fork head for expansion: %v", err) + } blocks, _ = GenerateChain(&proConf, conBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {}) if _, err := conBc.InsertChain(blocks); err == nil { t.Fatalf("contra-fork chain accepted pro-fork block: %v", blocks[0]) @@ -91,7 +94,7 @@ func TestDAOForkRangeExtradata(t *testing.T) { // Create a no-fork block, and try to feed into the pro-fork chain db, _ = ethdb.NewMemDatabase() gspec.MustCommit(db) - bc, _ = NewBlockChain(db, &proConf, ethash.NewFaker(), vm.Config{}) + bc, _ = NewBlockChain(db, nil, &proConf, ethash.NewFaker(), vm.Config{}) defer bc.Stop() blocks = proBc.GetBlocksFromHash(proBc.CurrentBlock().Hash(), int(proBc.CurrentBlock().NumberU64())) @@ -101,6 +104,9 @@ func TestDAOForkRangeExtradata(t *testing.T) { if _, err := bc.InsertChain(blocks); err != nil { t.Fatalf("failed to import pro-fork chain for expansion: %v", err) } + if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil { + t.Fatalf("failed to commit pro-fork head for expansion: %v", err) + } blocks, _ = GenerateChain(&conConf, proBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {}) if _, err := proBc.InsertChain(blocks); err == nil { t.Fatalf("pro-fork chain accepted contra-fork block: %v", blocks[0]) @@ -114,7 +120,7 @@ func TestDAOForkRangeExtradata(t *testing.T) { // Verify that contra-forkers accept pro-fork extra-datas after forking finishes db, _ = ethdb.NewMemDatabase() gspec.MustCommit(db) - bc, _ := NewBlockChain(db, &conConf, ethash.NewFaker(), vm.Config{}) + bc, _ := NewBlockChain(db, nil, &conConf, ethash.NewFaker(), vm.Config{}) defer bc.Stop() blocks := conBc.GetBlocksFromHash(conBc.CurrentBlock().Hash(), int(conBc.CurrentBlock().NumberU64())) @@ -124,6 +130,9 @@ func TestDAOForkRangeExtradata(t *testing.T) { if _, err := bc.InsertChain(blocks); err != nil { 
t.Fatalf("failed to import contra-fork chain for expansion: %v", err) } + if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil { + t.Fatalf("failed to commit contra-fork head for expansion: %v", err) + } blocks, _ = GenerateChain(&proConf, conBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {}) if _, err := conBc.InsertChain(blocks); err != nil { t.Fatalf("contra-fork chain didn't accept pro-fork block post-fork: %v", err) @@ -131,7 +140,7 @@ func TestDAOForkRangeExtradata(t *testing.T) { // Verify that pro-forkers accept contra-fork extra-datas after forking finishes db, _ = ethdb.NewMemDatabase() gspec.MustCommit(db) - bc, _ = NewBlockChain(db, &proConf, ethash.NewFaker(), vm.Config{}) + bc, _ = NewBlockChain(db, nil, &proConf, ethash.NewFaker(), vm.Config{}) defer bc.Stop() blocks = proBc.GetBlocksFromHash(proBc.CurrentBlock().Hash(), int(proBc.CurrentBlock().NumberU64())) @@ -141,6 +150,9 @@ func TestDAOForkRangeExtradata(t *testing.T) { if _, err := bc.InsertChain(blocks); err != nil { t.Fatalf("failed to import pro-fork chain for expansion: %v", err) } + if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil { + t.Fatalf("failed to commit pro-fork head for expansion: %v", err) + } blocks, _ = GenerateChain(&conConf, proBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {}) if _, err := proBc.InsertChain(blocks); err != nil { t.Fatalf("pro-fork chain didn't accept contra-fork block post-fork: %v", err) diff --git a/core/genesis.go b/core/genesis.go index e22985b80..b6ead2250 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -169,10 +169,9 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig // Check whether the genesis block is already written. 
if genesis != nil { - block, _ := genesis.ToBlock() - hash := block.Hash() + hash := genesis.ToBlock(nil).Hash() if hash != stored { - return genesis.Config, block.Hash(), &GenesisMismatchError{stored, hash} + return genesis.Config, hash, &GenesisMismatchError{stored, hash} } } @@ -220,9 +219,12 @@ func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig { } } -// ToBlock creates the block and state of a genesis specification. -func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) { - db, _ := ethdb.NewMemDatabase() +// ToBlock creates the genesis block and writes state of a genesis specification +// to the given database (or discards it if nil). +func (g *Genesis) ToBlock(db ethdb.Database) *types.Block { + if db == nil { + db, _ = ethdb.NewMemDatabase() + } statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) for addr, account := range g.Alloc { statedb.AddBalance(addr, account.Balance) @@ -252,19 +254,19 @@ func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) { if g.Difficulty == nil { head.Difficulty = params.GenesisDifficulty } - return types.NewBlock(head, nil, nil, nil), statedb + statedb.Commit(false) + statedb.Database().TrieDB().Commit(root, true) + + return types.NewBlock(head, nil, nil, nil) } // Commit writes the block and state of a genesis specification to the database. // The block is committed as the canonical head block. 
func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) { - block, statedb := g.ToBlock() + block := g.ToBlock(db) if block.Number().Sign() != 0 { return nil, fmt.Errorf("can't commit genesis block with number > 0") } - if _, err := statedb.CommitTo(db, false); err != nil { - return nil, fmt.Errorf("cannot write state: %v", err) - } if err := WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil { return nil, err } diff --git a/core/genesis_test.go b/core/genesis_test.go index 2fe931b24..cd548d4b1 100644 --- a/core/genesis_test.go +++ b/core/genesis_test.go @@ -30,11 +30,11 @@ import ( ) func TestDefaultGenesisBlock(t *testing.T) { - block, _ := DefaultGenesisBlock().ToBlock() + block := DefaultGenesisBlock().ToBlock(nil) if block.Hash() != params.MainnetGenesisHash { t.Errorf("wrong mainnet genesis hash, got %v, want %v", block.Hash(), params.MainnetGenesisHash) } - block, _ = DefaultTestnetGenesisBlock().ToBlock() + block = DefaultTestnetGenesisBlock().ToBlock(nil) if block.Hash() != params.TestnetGenesisHash { t.Errorf("wrong testnet genesis hash, got %v, want %v", block.Hash(), params.TestnetGenesisHash) } @@ -118,7 +118,7 @@ func TestSetupGenesis(t *testing.T) { // Commit the 'old' genesis block with Homestead transition at #2. // Advance to block #4, past the homestead transition block of customg. genesis := oldcustomg.MustCommit(db) - bc, _ := NewBlockChain(db, oldcustomg.Config, ethash.NewFullFaker(), vm.Config{}) + bc, _ := NewBlockChain(db, nil, oldcustomg.Config, ethash.NewFullFaker(), vm.Config{}) defer bc.Stop() bc.SetValidator(bproc{}) bc.InsertChain(makeBlockChainWithDiff(genesis, []int{2, 3, 4, 5}, 0)) diff --git a/core/state/database.go b/core/state/database.go index 946625e76..36926ec69 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -40,16 +40,23 @@ const ( // Database wraps access to tries and contract code. type Database interface { - // Accessing tries: // OpenTrie opens the main account trie. 
- // OpenStorageTrie opens the storage trie of an account. OpenTrie(root common.Hash) (Trie, error) + + // OpenStorageTrie opens the storage trie of an account. OpenStorageTrie(addrHash, root common.Hash) (Trie, error) - // Accessing contract code: - ContractCode(addrHash, codeHash common.Hash) ([]byte, error) - ContractCodeSize(addrHash, codeHash common.Hash) (int, error) + // CopyTrie returns an independent copy of the given trie. CopyTrie(Trie) Trie + + // ContractCode retrieves a particular contract's code. + ContractCode(addrHash, codeHash common.Hash) ([]byte, error) + + // ContractCodeSize retrieves a particular contract's code size. + ContractCodeSize(addrHash, codeHash common.Hash) (int, error) + + // TrieDB retrieves the low level trie database used for data storage. + TrieDB() *trie.Database } // Trie is a Ethereum Merkle Trie. @@ -57,26 +64,33 @@ type Trie interface { TryGet(key []byte) ([]byte, error) TryUpdate(key, value []byte) error TryDelete(key []byte) error - CommitTo(trie.DatabaseWriter) (common.Hash, error) + Commit(onleaf trie.LeafCallback) (common.Hash, error) Hash() common.Hash NodeIterator(startKey []byte) trie.NodeIterator GetKey([]byte) []byte // TODO(fjl): remove this when SecureTrie is removed + Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error } // NewDatabase creates a backing store for state. The returned database is safe for -// concurrent use and retains cached trie nodes in memory. +// concurrent use and retains cached trie nodes in memory. The given disk database +// is wrapped in an intermediate trie-node memory pool that sits between the low +// level storage layer and the high level trie abstraction. 
func NewDatabase(db ethdb.Database) Database { csc, _ := lru.New(codeSizeCacheSize) - return &cachingDB{db: db, codeSizeCache: csc} + return &cachingDB{ + db: trie.NewDatabase(db), + codeSizeCache: csc, + } } type cachingDB struct { - db ethdb.Database + db *trie.Database mu sync.Mutex pastTries []*trie.SecureTrie codeSizeCache *lru.Cache } +// OpenTrie opens the main account trie. func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { db.mu.Lock() defer db.mu.Unlock() @@ -105,10 +119,12 @@ func (db *cachingDB) pushTrie(t *trie.SecureTrie) { } } +// OpenStorageTrie opens the storage trie of an account. func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { return trie.NewSecure(root, db.db, 0) } +// CopyTrie returns an independent copy of the given trie. func (db *cachingDB) CopyTrie(t Trie) Trie { switch t := t.(type) { case cachedTrie: @@ -120,14 +136,16 @@ func (db *cachingDB) CopyTrie(t Trie) Trie { } } +// ContractCode retrieves a particular contract's code. func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) { - code, err := db.db.Get(codeHash[:]) + code, err := db.db.Node(codeHash) if err == nil { db.codeSizeCache.Add(codeHash, len(code)) } return code, err } +// ContractCodeSize retrieves a particular contracts code's size. func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) { if cached, ok := db.codeSizeCache.Get(codeHash); ok { return cached.(int), nil @@ -139,16 +157,25 @@ func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, erro return len(code), err } +// TrieDB retrieves any intermediate trie-node caching layer. +func (db *cachingDB) TrieDB() *trie.Database { + return db.db +} + // cachedTrie inserts its trie into a cachingDB on commit. 
type cachedTrie struct { *trie.SecureTrie db *cachingDB } -func (m cachedTrie) CommitTo(dbw trie.DatabaseWriter) (common.Hash, error) { - root, err := m.SecureTrie.CommitTo(dbw) +func (m cachedTrie) Commit(onleaf trie.LeafCallback) (common.Hash, error) { + root, err := m.SecureTrie.Commit(onleaf) if err == nil { m.db.pushTrie(m.SecureTrie) } return root, err } + +func (m cachedTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error { + return m.SecureTrie.Prove(key, fromLevel, proofDb) +} diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go index ff66ba7a9..9e46c851c 100644 --- a/core/state/iterator_test.go +++ b/core/state/iterator_test.go @@ -21,12 +21,13 @@ import ( "testing" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" ) // Tests that the node iterator indeed walks over the entire database contents. func TestNodeIteratorCoverage(t *testing.T) { // Create some arbitrary test state to iterate - db, mem, root, _ := makeTestState() + db, root, _ := makeTestState() state, err := New(root, db) if err != nil { @@ -39,14 +40,18 @@ func TestNodeIteratorCoverage(t *testing.T) { hashes[it.Hash] = struct{}{} } } - - // Cross check the hashes and the database itself + // Cross check the iterated hashes and the database/nodepool content for hash := range hashes { - if _, err := mem.Get(hash.Bytes()); err != nil { - t.Errorf("failed to retrieve reported node %x: %v", hash, err) + if _, err := db.TrieDB().Node(hash); err != nil { + t.Errorf("failed to retrieve reported node %x", hash) } } - for _, key := range mem.Keys() { + for _, hash := range db.TrieDB().Nodes() { + if _, ok := hashes[hash]; !ok { + t.Errorf("state entry not reported %x", hash) + } + } + for _, key := range db.TrieDB().DiskDB().(*ethdb.MemDatabase).Keys() { if bytes.HasPrefix(key, []byte("secure-key-")) { continue } diff --git a/core/state/state_object.go b/core/state/state_object.go index b2378c69c..b2112bfae 100644 --- 
a/core/state/state_object.go +++ b/core/state/state_object.go @@ -25,7 +25,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" ) var emptyCodeHash = crypto.Keccak256(nil) @@ -238,12 +237,12 @@ func (self *stateObject) updateRoot(db Database) { // CommitTrie the storage trie of the object to dwb. // This updates the trie root. -func (self *stateObject) CommitTrie(db Database, dbw trie.DatabaseWriter) error { +func (self *stateObject) CommitTrie(db Database) error { self.updateTrie(db) if self.dbErr != nil { return self.dbErr } - root, err := self.trie.CommitTo(dbw) + root, err := self.trie.Commit(nil) if err == nil { self.data.Root = root } diff --git a/core/state/state_test.go b/core/state/state_test.go index bbae3685b..6d42d63d8 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -48,7 +48,7 @@ func (s *StateSuite) TestDump(c *checker.C) { // write some of them to the trie s.state.updateStateObject(obj1) s.state.updateStateObject(obj2) - s.state.CommitTo(s.db, false) + s.state.Commit(false) // check that dump contains the state objects that are in trie got := string(s.state.Dump()) @@ -97,7 +97,7 @@ func (s *StateSuite) TestNull(c *checker.C) { //value := common.FromHex("0x823140710bf13990e4500136726d8b55") var value common.Hash s.state.SetState(address, common.Hash{}, value) - s.state.CommitTo(s.db, false) + s.state.Commit(false) value = s.state.GetState(address, common.Hash{}) if !common.EmptyHash(value) { c.Errorf("expected empty hash. 
got %x", value) @@ -155,7 +155,7 @@ func TestSnapshot2(t *testing.T) { so0.deleted = false state.setStateObject(so0) - root, _ := state.CommitTo(db, false) + root, _ := state.Commit(false) state.Reset(root) // and one with deleted == true diff --git a/core/state/statedb.go b/core/state/statedb.go index 8e29104d5..776693e24 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -36,6 +36,14 @@ type revision struct { journalIndex int } +var ( + // emptyState is the known hash of an empty state trie entry. + emptyState = crypto.Keccak256Hash(nil) + + // emptyCode is the known hash of the empty EVM bytecode. + emptyCode = crypto.Keccak256Hash(nil) +) + // StateDBs within the ethereum protocol are used to store anything // within the merkle trie. StateDBs take care of caching and storing // nested states. It's the general query interface to retrieve: @@ -235,6 +243,11 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash { return common.Hash{} } +// Database retrieves the low level database supporting the lower level trie ops. +func (self *StateDB) Database() Database { + return self.db +} + // StorageTrie returns the storage trie of an account. // The return value is a copy and is nil for non-existent accounts. func (self *StateDB) StorageTrie(a common.Address) Trie { @@ -568,8 +581,8 @@ func (s *StateDB) clearJournalAndRefund() { s.refund = 0 } -// CommitTo writes the state to the given database. -func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (root common.Hash, err error) { +// Commit writes the state to the underlying in-memory trie database. +func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) { defer s.clearJournalAndRefund() // Commit objects to the trie. 
@@ -583,13 +596,11 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro case isDirty: // Write any contract code associated with the state object if stateObject.code != nil && stateObject.dirtyCode { - if err := dbw.Put(stateObject.CodeHash(), stateObject.code); err != nil { - return common.Hash{}, err - } + s.db.TrieDB().Insert(common.BytesToHash(stateObject.CodeHash()), stateObject.code) stateObject.dirtyCode = false } // Write any storage changes in the state object to its storage trie. - if err := stateObject.CommitTrie(s.db, dbw); err != nil { + if err := stateObject.CommitTrie(s.db); err != nil { return common.Hash{}, err } // Update the object in the main account trie. @@ -598,7 +609,20 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro delete(s.stateObjectsDirty, addr) } // Write trie changes. - root, err = s.trie.CommitTo(dbw) + root, err = s.trie.Commit(func(leaf []byte, parent common.Hash) error { + var account Account + if err := rlp.DecodeBytes(leaf, &account); err != nil { + return nil + } + if account.Root != emptyState { + s.db.TrieDB().Reference(account.Root, parent) + } + code := common.BytesToHash(account.CodeHash) + if code != emptyCode { + s.db.TrieDB().Reference(code, parent) + } + return nil + }) log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads()) return root, err } diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 5c80e3aa5..d9e3d9b79 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -97,10 +97,10 @@ func TestIntermediateLeaks(t *testing.T) { } // Commit and cross check the databases. 
- if _, err := transState.CommitTo(transDb, false); err != nil { + if _, err := transState.Commit(false); err != nil { t.Fatalf("failed to commit transition state: %v", err) } - if _, err := finalState.CommitTo(finalDb, false); err != nil { + if _, err := finalState.Commit(false); err != nil { t.Fatalf("failed to commit final state: %v", err) } for _, key := range finalDb.Keys() { @@ -122,8 +122,8 @@ func TestIntermediateLeaks(t *testing.T) { // https://github.com/ethereum/go-ethereum/pull/15549. func TestCopy(t *testing.T) { // Create a random state test to copy and modify "independently" - mem, _ := ethdb.NewMemDatabase() - orig, _ := New(common.Hash{}, NewDatabase(mem)) + db, _ := ethdb.NewMemDatabase() + orig, _ := New(common.Hash{}, NewDatabase(db)) for i := byte(0); i < 255; i++ { obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) @@ -346,11 +346,10 @@ func (test *snapshotTest) run() bool { } action.fn(action, state) } - // Revert all snapshots in reverse order. Each revert must yield a state // that is equivalent to fresh state with all actions up the snapshot applied. 
for sindex--; sindex >= 0; sindex-- { - checkstate, _ := New(common.Hash{}, NewDatabase(db)) + checkstate, _ := New(common.Hash{}, state.Database()) for _, action := range test.actions[:test.snapshots[sindex]] { action.fn(action, checkstate) } @@ -409,7 +408,7 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { func (s *StateSuite) TestTouchDelete(c *check.C) { s.state.GetOrNewStateObject(common.Address{}) - root, _ := s.state.CommitTo(s.db, false) + root, _ := s.state.Commit(false) s.state.Reset(root) snapshot := s.state.Snapshot() @@ -417,7 +416,6 @@ func (s *StateSuite) TestTouchDelete(c *check.C) { if len(s.state.stateObjectsDirty) != 1 { c.Fatal("expected one dirty state object") } - s.state.RevertToSnapshot(snapshot) if len(s.state.stateObjectsDirty) != 0 { c.Fatal("expected no dirty state object") diff --git a/core/state/sync_test.go b/core/state/sync_test.go index 06c572ea6..8f14a44e7 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -36,10 +36,10 @@ type testAccount struct { } // makeTestState create a sample test state to test node-wise reconstruction. 
-func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount) { +func makeTestState() (Database, common.Hash, []*testAccount) { // Create an empty state - mem, _ := ethdb.NewMemDatabase() - db := NewDatabase(mem) + diskdb, _ := ethdb.NewMemDatabase() + db := NewDatabase(diskdb) state, _ := New(common.Hash{}, db) // Fill it with some arbitrary data @@ -61,10 +61,10 @@ func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount) state.updateStateObject(obj) accounts = append(accounts, acc) } - root, _ := state.CommitTo(mem, false) + root, _ := state.Commit(false) // Return the generated state - return db, mem, root, accounts + return db, root, accounts } // checkStateAccounts cross references a reconstructed state with an expected @@ -96,7 +96,7 @@ func checkTrieConsistency(db ethdb.Database, root common.Hash) error { if v, _ := db.Get(root[:]); v == nil { return nil // Consider a non existent state consistent. } - trie, err := trie.New(root, db) + trie, err := trie.New(root, trie.NewDatabase(db)) if err != nil { return err } @@ -138,7 +138,7 @@ func TestIterativeStateSyncBatched(t *testing.T) { testIterativeStateSync(t, func testIterativeStateSync(t *testing.T, batch int) { // Create a random state to copy - _, srcMem, srcRoot, srcAccounts := makeTestState() + srcDb, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -148,9 +148,9 @@ func testIterativeStateSync(t *testing.T, batch int) { for len(queue) > 0 { results := make([]trie.SyncResult, len(queue)) for i, hash := range queue { - data, err := srcMem.Get(hash.Bytes()) + data, err := srcDb.TrieDB().Node(hash) if err != nil { - t.Fatalf("failed to retrieve node data for %x: %v", hash, err) + t.Fatalf("failed to retrieve node data for %x", hash) } results[i] = trie.SyncResult{Hash: hash, Data: data} } @@ -170,7 +170,7 @@ func testIterativeStateSync(t *testing.T, batch int) { // partial 
results are returned, and the others sent only later. func TestIterativeDelayedStateSync(t *testing.T) { // Create a random state to copy - _, srcMem, srcRoot, srcAccounts := makeTestState() + srcDb, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -181,9 +181,9 @@ func TestIterativeDelayedStateSync(t *testing.T) { // Sync only half of the scheduled nodes results := make([]trie.SyncResult, len(queue)/2+1) for i, hash := range queue[:len(results)] { - data, err := srcMem.Get(hash.Bytes()) + data, err := srcDb.TrieDB().Node(hash) if err != nil { - t.Fatalf("failed to retrieve node data for %x: %v", hash, err) + t.Fatalf("failed to retrieve node data for %x", hash) } results[i] = trie.SyncResult{Hash: hash, Data: data} } @@ -207,7 +207,7 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS func testIterativeRandomStateSync(t *testing.T, batch int) { // Create a random state to copy - _, srcMem, srcRoot, srcAccounts := makeTestState() + srcDb, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -221,9 +221,9 @@ func testIterativeRandomStateSync(t *testing.T, batch int) { // Fetch all the queued nodes in a random order results := make([]trie.SyncResult, 0, len(queue)) for hash := range queue { - data, err := srcMem.Get(hash.Bytes()) + data, err := srcDb.TrieDB().Node(hash) if err != nil { - t.Fatalf("failed to retrieve node data for %x: %v", hash, err) + t.Fatalf("failed to retrieve node data for %x", hash) } results = append(results, trie.SyncResult{Hash: hash, Data: data}) } @@ -247,7 +247,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) { // partial results are returned (Even those randomly), others sent only later. 
func TestIterativeRandomDelayedStateSync(t *testing.T) { // Create a random state to copy - _, srcMem, srcRoot, srcAccounts := makeTestState() + srcDb, srcRoot, srcAccounts := makeTestState() // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -263,9 +263,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { for hash := range queue { delete(queue, hash) - data, err := srcMem.Get(hash.Bytes()) + data, err := srcDb.TrieDB().Node(hash) if err != nil { - t.Fatalf("failed to retrieve node data for %x: %v", hash, err) + t.Fatalf("failed to retrieve node data for %x", hash) } results = append(results, trie.SyncResult{Hash: hash, Data: data}) @@ -292,9 +292,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) { // the database. func TestIncompleteStateSync(t *testing.T) { // Create a random state to copy - _, srcMem, srcRoot, srcAccounts := makeTestState() + srcDb, srcRoot, srcAccounts := makeTestState() - checkTrieConsistency(srcMem, srcRoot) + checkTrieConsistency(srcDb.TrieDB().DiskDB().(ethdb.Database), srcRoot) // Create a destination state and sync with the scheduler dstDb, _ := ethdb.NewMemDatabase() @@ -306,9 +306,9 @@ func TestIncompleteStateSync(t *testing.T) { // Fetch a batch of state nodes results := make([]trie.SyncResult, len(queue)) for i, hash := range queue { - data, err := srcMem.Get(hash.Bytes()) + data, err := srcDb.TrieDB().Node(hash) if err != nil { - t.Fatalf("failed to retrieve node data for %x: %v", hash, err) + t.Fatalf("failed to retrieve node data for %x", hash) } results[i] = trie.SyncResult{Hash: hash, Data: data} } diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index cd11f2ba2..158b9776b 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -78,8 +78,8 @@ func pricedTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ec } func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { - db, _ := ethdb.NewMemDatabase() - statedb, _ := state.New(common.Hash{}, 
state.NewDatabase(db)) + diskdb, _ := ethdb.NewMemDatabase() + statedb, _ := state.New(common.Hash{}, state.NewDatabase(diskdb)) blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} key, _ := crypto.GenerateKey() diff --git a/core/types/block.go b/core/types/block.go index ffe317342..92b868d9d 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -25,6 +25,7 @@ import ( "sort" "sync/atomic" "time" + "unsafe" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" @@ -121,6 +122,12 @@ func (h *Header) HashNoNonce() common.Hash { }) } +// Size returns the approximate memory used by all internal contents. It is used +// to approximate and limit the memory consumption of various caches. +func (h *Header) Size() common.StorageSize { + return common.StorageSize(unsafe.Sizeof(*h)) + common.StorageSize(len(h.Extra)+(h.Difficulty.BitLen()+h.Number.BitLen()+h.Time.BitLen())/8) +} + func rlpHash(x interface{}) (h common.Hash) { hw := sha3.NewKeccak256() rlp.Encode(hw, x) @@ -322,6 +329,8 @@ func (b *Block) HashNoNonce() common.Hash { return b.header.HashNoNonce() } +// Size returns the true RLP encoded storage size of the block, either by encoding +// and returning it, or returning a previously cached value. func (b *Block) Size() common.StorageSize { if size := b.size.Load(); size != nil { return size.(common.StorageSize) diff --git a/core/types/receipt.go b/core/types/receipt.go index 208d54aaa..f945f6f6a 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -20,6 +20,7 @@ import ( "bytes" "fmt" "io" + "unsafe" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" @@ -136,6 +137,18 @@ func (r *Receipt) statusEncoding() []byte { return r.PostState } +// Size returns the approximate memory used by all internal contents. It is used +// to approximate and limit the memory consumption of various caches. 
+func (r *Receipt) Size() common.StorageSize { + size := common.StorageSize(unsafe.Sizeof(*r)) + common.StorageSize(len(r.PostState)) + + size += common.StorageSize(len(r.Logs)) * common.StorageSize(unsafe.Sizeof(Log{})) + for _, log := range r.Logs { + size += common.StorageSize(len(log.Topics)*common.HashLength + len(log.Data)) + } + return size +} + // String implements the Stringer interface. func (r *Receipt) String() string { if len(r.PostState) == 0 { diff --git a/core/types/transaction.go b/core/types/transaction.go index a7ed211e4..5660582ba 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -206,6 +206,8 @@ func (tx *Transaction) Hash() common.Hash { return v } +// Size returns the true RLP encoded storage size of the transaction, either by +// encoding and returning it, or returning a previously cached value. func (tx *Transaction) Size() common.StorageSize { if size := tx.size.Load(); size != nil { return size.(common.StorageSize) diff --git a/eth/api.go b/eth/api.go index 0db3eb554..a345b57e4 100644 --- a/eth/api.go +++ b/eth/api.go @@ -462,11 +462,11 @@ func (api *PrivateDebugAPI) getModifiedAccounts(startBlock, endBlock *types.Bloc return nil, fmt.Errorf("start block height (%d) must be less than end block height (%d)", startBlock.Number().Uint64(), endBlock.Number().Uint64()) } - oldTrie, err := trie.NewSecure(startBlock.Root(), api.eth.chainDb, 0) + oldTrie, err := trie.NewSecure(startBlock.Root(), trie.NewDatabase(api.eth.chainDb), 0) if err != nil { return nil, err } - newTrie, err := trie.NewSecure(endBlock.Root(), api.eth.chainDb, 0) + newTrie, err := trie.NewSecure(endBlock.Root(), trie.NewDatabase(api.eth.chainDb), 0) if err != nil { return nil, err } diff --git a/eth/api_tracer.go b/eth/api_tracer.go index d49f077ae..07c4457bc 100644 --- a/eth/api_tracer.go +++ b/eth/api_tracer.go @@ -24,7 +24,6 @@ import ( "io/ioutil" "runtime" "sync" - "sync/atomic" "time" "github.com/ethereum/go-ethereum/common" @@ -34,7 +33,6 @@ 
import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/eth/tracers" - "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" @@ -72,6 +70,7 @@ type txTraceResult struct { type blockTraceTask struct { statedb *state.StateDB // Intermediate state prepped for tracing block *types.Block // Block to trace the transactions from + rootref common.Hash // Trie root reference held for this task results []*txTraceResult // Trace results procudes by the task } @@ -90,59 +89,6 @@ type txTraceTask struct { index int // Transaction offset in the block } -// ephemeralDatabase is a memory wrapper around a proper database, which acts as -// an ephemeral write layer. This construct is used by the chain tracer to write -// state tries for intermediate blocks without serializing to disk, but at the -// same time to allow disk fallback for reads that do no hit the memory layer. 
-type ephemeralDatabase struct { - diskdb ethdb.Database // Persistent disk database to fall back to with reads - memdb *ethdb.MemDatabase // Ephemeral memory database for primary reads and writes -} - -func (db *ephemeralDatabase) Put(key []byte, value []byte) error { return db.memdb.Put(key, value) } -func (db *ephemeralDatabase) Delete(key []byte) error { return errors.New("delete not supported") } -func (db *ephemeralDatabase) Close() { db.memdb.Close() } -func (db *ephemeralDatabase) NewBatch() ethdb.Batch { - return db.memdb.NewBatch() -} -func (db *ephemeralDatabase) Has(key []byte) (bool, error) { - if has, _ := db.memdb.Has(key); has { - return has, nil - } - return db.diskdb.Has(key) -} -func (db *ephemeralDatabase) Get(key []byte) ([]byte, error) { - if blob, _ := db.memdb.Get(key); blob != nil { - return blob, nil - } - return db.diskdb.Get(key) -} - -// Prune does a state sync into a new memory write layer and replaces the old one. -// This allows us to discard entries that are no longer referenced from the current -// state. -func (db *ephemeralDatabase) Prune(root common.Hash) { - // Pull the still relevant state data into memory - sync := state.NewStateSync(root, db.diskdb) - for sync.Pending() > 0 { - hash := sync.Missing(1)[0] - - // Move the next trie node from the memory layer into a sync struct - node, err := db.memdb.Get(hash[:]) - if err != nil { - panic(err) // memdb must have the data - } - if _, _, err := sync.Process([]trie.SyncResult{{Hash: hash, Data: node}}); err != nil { - panic(err) // it's not possible to fail processing a node - } - } - // Discard the old memory layer and write a new one - db.memdb, _ = ethdb.NewMemDatabaseWithCap(db.memdb.Len()) - if _, err := sync.Commit(db); err != nil { - panic(err) // writing into a memdb cannot fail - } -} - // TraceChain returns the structured logs created during the execution of EVM // between two blocks (excluding start) and returns them as a JSON object. 
func (api *PrivateDebugAPI) TraceChain(ctx context.Context, start, end rpc.BlockNumber, config *TraceConfig) (*rpc.Subscription, error) { @@ -188,19 +134,15 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl // Ensure we have a valid starting state before doing any work origin := start.NumberU64() + database := state.NewDatabase(api.eth.ChainDb()) - memdb, _ := ethdb.NewMemDatabase() - db := &ephemeralDatabase{ - diskdb: api.eth.ChainDb(), - memdb: memdb, - } if number := start.NumberU64(); number > 0 { start = api.eth.blockchain.GetBlock(start.ParentHash(), start.NumberU64()-1) if start == nil { return nil, fmt.Errorf("parent block #%d not found", number-1) } } - statedb, err := state.New(start.Root(), state.NewDatabase(db)) + statedb, err := state.New(start.Root(), database) if err != nil { // If the starting state is missing, allow some number of blocks to be reexecuted reexec := defaultTraceReexec @@ -213,7 +155,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl if start == nil { break } - if statedb, err = state.New(start.Root(), state.NewDatabase(db)); err == nil { + if statedb, err = state.New(start.Root(), database); err == nil { break } } @@ -256,7 +198,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) if err != nil { task.results[i] = &txTraceResult{Error: err.Error()} - log.Warn("Tracing failed", "err", err) + log.Warn("Tracing failed", "hash", tx.Hash(), "block", task.block.NumberU64(), "err", err) break } task.statedb.DeleteSuicides() @@ -273,7 +215,6 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl } // Start a goroutine to feed all the blocks into the tracers begin := time.Now() - complete := start.NumberU64() go func() { var ( @@ -281,6 +222,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl number uint64 traced 
uint64 failed error + proot common.Hash ) // Ensure everything is properly cleaned up on any exit path defer func() { @@ -308,7 +250,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl // Print progress logs if long enough time elapsed if time.Since(logged) > 8*time.Second { if number > origin { - log.Info("Tracing chain segment", "start", origin, "end", end.NumberU64(), "current", number, "transactions", traced, "elapsed", time.Since(begin)) + log.Info("Tracing chain segment", "start", origin, "end", end.NumberU64(), "current", number, "transactions", traced, "elapsed", time.Since(begin), "memory", database.TrieDB().Size()) } else { log.Info("Preparing state for chain trace", "block", number, "start", origin, "elapsed", time.Since(begin)) } @@ -325,13 +267,11 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl txs := block.Transactions() select { - case tasks <- &blockTraceTask{statedb: statedb.Copy(), block: block, results: make([]*txTraceResult, len(txs))}: + case tasks <- &blockTraceTask{statedb: statedb.Copy(), block: block, rootref: proot, results: make([]*txTraceResult, len(txs))}: case <-notifier.Closed(): return } traced += uint64(len(txs)) - } else { - atomic.StoreUint64(&complete, number) } // Generate the next state snapshot fast without tracing _, _, _, err := api.eth.blockchain.Processor().Process(block, statedb, vm.Config{}) @@ -340,7 +280,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl break } // Finalize the state so any modifications are written to the trie - root, err := statedb.CommitTo(db, true) + root, err := statedb.Commit(true) if err != nil { failed = err break @@ -349,26 +289,14 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl failed = err break } - // After every N blocks, prune the database to only retain relevant data - if (number-start.NumberU64())%4096 == 0 { - // Wait until currently pending trace 
jobs finish - for atomic.LoadUint64(&complete) != number { - select { - case <-time.After(100 * time.Millisecond): - case <-notifier.Closed(): - return - } - } - // No more concurrent access at this point, prune the database - var ( - nodes = db.memdb.Len() - start = time.Now() - ) - db.Prune(root) - log.Info("Pruned tracer state entries", "deleted", nodes-db.memdb.Len(), "left", db.memdb.Len(), "elapsed", time.Since(start)) - - statedb, _ = state.New(root, state.NewDatabase(db)) + // Reference the trie twice, once for us, once for the tracer + database.TrieDB().Reference(root, common.Hash{}) + if number >= origin { + database.TrieDB().Reference(root, common.Hash{}) } + // Dereference all past tries we ourselves are done working with + database.TrieDB().Dereference(proot, common.Hash{}) + proot = root } }() @@ -387,12 +315,14 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl } done[uint64(result.Block)] = result + // Dereference any parent tries held in memory by this task + database.TrieDB().Dereference(res.rootref, common.Hash{}) + // Stream completed traces to the user, aborting on the first error for result, ok := done[next]; ok; result, ok = done[next] { if len(result.Traces) > 0 || next == end.NumberU64() { notifier.Notify(sub.ID, result) } - atomic.StoreUint64(&complete, next) delete(done, next) next++ } @@ -544,18 +474,14 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (* } // Otherwise try to reexec blocks until we find a state or reach our limit origin := block.NumberU64() + database := state.NewDatabase(api.eth.ChainDb()) - memdb, _ := ethdb.NewMemDatabase() - db := &ephemeralDatabase{ - diskdb: api.eth.ChainDb(), - memdb: memdb, - } for i := uint64(0); i < reexec; i++ { block = api.eth.blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1) if block == nil { break } - if statedb, err = state.New(block.Root(), state.NewDatabase(db)); err == nil { + if statedb, err = 
state.New(block.Root(), database); err == nil { break } } @@ -571,6 +497,7 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (* var ( start = time.Now() logged time.Time + proot common.Hash ) for block.NumberU64() < origin { // Print progress logs if long enough time elapsed @@ -587,26 +514,18 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (* return nil, err } // Finalize the state so any modifications are written to the trie - root, err := statedb.CommitTo(db, true) + root, err := statedb.Commit(true) if err != nil { return nil, err } if err := statedb.Reset(root); err != nil { return nil, err } - // After every N blocks, prune the database to only retain relevant data - if block.NumberU64()%4096 == 0 || block.NumberU64() == origin { - var ( - nodes = db.memdb.Len() - begin = time.Now() - ) - db.Prune(root) - log.Info("Pruned tracer state entries", "deleted", nodes-db.memdb.Len(), "left", db.memdb.Len(), "elapsed", time.Since(begin)) - - statedb, _ = state.New(root, state.NewDatabase(db)) - } + database.TrieDB().Reference(root, common.Hash{}) + database.TrieDB().Dereference(proot, common.Hash{}) + proot = root } - log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start)) + log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start), "size", database.TrieDB().Size()) return statedb, nil } diff --git a/eth/backend.go b/eth/backend.go index bcd724c0c..94aad2310 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -144,9 +144,11 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { } core.WriteBlockChainVersion(chainDb, core.BlockChainVersion) } - - vmConfig := vm.Config{EnablePreimageRecording: config.EnablePreimageRecording} - eth.blockchain, err = core.NewBlockChain(chainDb, eth.chainConfig, eth.engine, vmConfig) + var ( + vmConfig = vm.Config{EnablePreimageRecording: config.EnablePreimageRecording} + 
cacheConfig = &core.CacheConfig{Disabled: config.NoPruning, TrieNodeLimit: config.TrieCache, TrieTimeLimit: config.TrieTimeout} + ) + eth.blockchain, err = core.NewBlockChain(chainDb, cacheConfig, eth.chainConfig, eth.engine, vmConfig) if err != nil { return nil, err } diff --git a/eth/config.go b/eth/config.go index 2158c71ba..dd7f42c7d 100644 --- a/eth/config.go +++ b/eth/config.go @@ -22,6 +22,7 @@ import ( "os/user" "path/filepath" "runtime" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" @@ -44,7 +45,9 @@ var DefaultConfig = Config{ }, NetworkId: 1, LightPeers: 100, - DatabaseCache: 128, + DatabaseCache: 768, + TrieCache: 256, + TrieTimeout: 5 * time.Minute, GasPrice: big.NewInt(18 * params.Shannon), TxPool: core.DefaultTxPoolConfig, @@ -78,6 +81,7 @@ type Config struct { // Protocol options NetworkId uint64 // Network ID to use for selecting peers to connect to SyncMode downloader.SyncMode + NoPruning bool // Light client options LightServ int `toml:",omitempty"` // Maximum percentage of time allowed for serving LES requests @@ -87,6 +91,8 @@ type Config struct { SkipBcVersionCheck bool `toml:"-"` DatabaseHandles int `toml:"-"` DatabaseCache int + TrieCache int + TrieTimeout time.Duration // Mining-related options Etherbase common.Address `toml:",omitempty"` diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 746c6a402..7f490d9e9 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -18,10 +18,8 @@ package downloader import ( - "crypto/rand" "errors" "fmt" - "math" "math/big" "sync" "sync/atomic" @@ -61,12 +59,11 @@ var ( maxHeadersProcess = 2048 // Number of header download results to import at once into the chain maxResultsProcess = 2048 // Number of content download results to import at once into the chain - fsHeaderCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync - fsHeaderSafetyNet = 2048 // Number of headers to 
discard in case a chain violation is detected - fsHeaderForceVerify = 24 // Number of headers to verify before and after the pivot to accept it - fsPivotInterval = 256 // Number of headers out of which to randomize the pivot point - fsMinFullBlocks = 64 // Number of blocks to retrieve fully even in fast sync - fsCriticalTrials = uint32(32) // Number of times to retry in the cricical section before bailing + fsHeaderCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync + fsHeaderSafetyNet = 2048 // Number of headers to discard in case a chain violation is detected + fsHeaderForceVerify = 24 // Number of headers to verify before and after the pivot to accept it + fsHeaderContCheck = 3 * time.Second // Time interval to check for header continuations during state download + fsMinFullBlocks = 64 // Number of blocks to retrieve fully even in fast sync ) var ( @@ -102,9 +99,6 @@ type Downloader struct { peers *peerSet // Set of active peers from which download can proceed stateDB ethdb.Database - fsPivotLock *types.Header // Pivot header on critical section entry (cannot change between retries) - fsPivotFails uint32 // Number of subsequent fast sync failures in the critical section - rttEstimate uint64 // Round trip time to target for download requests rttConfidence uint64 // Confidence in the estimated RTT (unit: millionths to allow atomic ops) @@ -124,6 +118,7 @@ type Downloader struct { synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing synchronising int32 notified int32 + committed int32 // Channels headerCh chan dataPack // [eth/62] Channel receiving inbound block headers @@ -156,7 +151,7 @@ type Downloader struct { // LightChain encapsulates functions required to synchronise a light chain. type LightChain interface { // HasHeader verifies a header's presence in the local chain. 
- HasHeader(h common.Hash, number uint64) bool + HasHeader(common.Hash, uint64) bool // GetHeaderByHash retrieves a header from the local chain. GetHeaderByHash(common.Hash) *types.Header @@ -179,7 +174,7 @@ type BlockChain interface { LightChain // HasBlockAndState verifies block and associated states' presence in the local chain. - HasBlockAndState(common.Hash) bool + HasBlockAndState(common.Hash, uint64) bool // GetBlockByHash retrieves a block from the local chain. GetBlockByHash(common.Hash) *types.Block @@ -391,9 +386,7 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode // Set the requested sync mode, unless it's forbidden d.mode = mode - if d.mode == FastSync && atomic.LoadUint32(&d.fsPivotFails) >= fsCriticalTrials { - d.mode = FullSync - } + // Retrieve the origin peer and initiate the downloading process p := d.peers.Peer(id) if p == nil { @@ -441,57 +434,40 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I d.syncStatsChainHeight = height d.syncStatsLock.Unlock() - // Initiate the sync using a concurrent header and content retrieval algorithm + // Ensure our origin point is below any fast sync pivot point pivot := uint64(0) - switch d.mode { - case LightSync: - pivot = height - case FastSync: - // Calculate the new fast/slow sync pivot point - if d.fsPivotLock == nil { - pivotOffset, err := rand.Int(rand.Reader, big.NewInt(int64(fsPivotInterval))) - if err != nil { - panic(fmt.Sprintf("Failed to access crypto random source: %v", err)) - } - if height > uint64(fsMinFullBlocks)+pivotOffset.Uint64() { - pivot = height - uint64(fsMinFullBlocks) - pivotOffset.Uint64() - } + if d.mode == FastSync { + if height <= uint64(fsMinFullBlocks) { + origin = 0 } else { - // Pivot point locked in, use this and do not pick a new one! 
- pivot = d.fsPivotLock.Number.Uint64() - } - // If the point is below the origin, move origin back to ensure state download - if pivot < origin { - if pivot > 0 { + pivot = height - uint64(fsMinFullBlocks) + if pivot <= origin { origin = pivot - 1 - } else { - origin = 0 } } - log.Debug("Fast syncing until pivot block", "pivot", pivot) } - d.queue.Prepare(origin+1, d.mode, pivot, latest) + d.committed = 1 + if d.mode == FastSync && pivot != 0 { + d.committed = 0 + } + // Initiate the sync using a concurrent header and content retrieval algorithm + d.queue.Prepare(origin+1, d.mode) if d.syncInitHook != nil { d.syncInitHook(origin, height) } fetchers := []func() error{ - func() error { return d.fetchHeaders(p, origin+1) }, // Headers are always retrieved - func() error { return d.fetchBodies(origin + 1) }, // Bodies are retrieved during normal and fast sync - func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync - func() error { return d.processHeaders(origin+1, td) }, + func() error { return d.fetchHeaders(p, origin+1, pivot) }, // Headers are always retrieved + func() error { return d.fetchBodies(origin + 1) }, // Bodies are retrieved during normal and fast sync + func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync + func() error { return d.processHeaders(origin+1, pivot, td) }, } if d.mode == FastSync { fetchers = append(fetchers, func() error { return d.processFastSyncContent(latest) }) } else if d.mode == FullSync { fetchers = append(fetchers, d.processFullSyncContent) } - err = d.spawnSync(fetchers) - if err != nil && d.mode == FastSync && d.fsPivotLock != nil { - // If sync failed in the critical section, bump the fail counter. 
- atomic.AddUint32(&d.fsPivotFails, 1) - } - return err + return d.spawnSync(fetchers) } // spawnSync runs d.process and all given fetcher functions to completion in @@ -671,7 +647,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err continue } // Otherwise check if we already know the header or not - if (d.mode == FullSync && d.blockchain.HasBlockAndState(headers[i].Hash())) || (d.mode != FullSync && d.lightchain.HasHeader(headers[i].Hash(), headers[i].Number.Uint64())) { + if (d.mode == FullSync && d.blockchain.HasBlockAndState(headers[i].Hash(), headers[i].Number.Uint64())) || (d.mode != FullSync && d.lightchain.HasHeader(headers[i].Hash(), headers[i].Number.Uint64())) { number, hash = headers[i].Number.Uint64(), headers[i].Hash() // If every header is known, even future ones, the peer straight out lied about its head @@ -736,7 +712,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err arrived = true // Modify the search interval based on the response - if (d.mode == FullSync && !d.blockchain.HasBlockAndState(headers[0].Hash())) || (d.mode != FullSync && !d.lightchain.HasHeader(headers[0].Hash(), headers[0].Number.Uint64())) { + if (d.mode == FullSync && !d.blockchain.HasBlockAndState(headers[0].Hash(), headers[0].Number.Uint64())) || (d.mode != FullSync && !d.lightchain.HasHeader(headers[0].Hash(), headers[0].Number.Uint64())) { end = check break } @@ -774,7 +750,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err // other peers are only accepted if they map cleanly to the skeleton. If no one // can fill in the skeleton - not even the origin peer - it's assumed invalid and // the origin is dropped. 
-func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error { +func (d *Downloader) fetchHeaders(p *peerConnection, from uint64, pivot uint64) error { p.log.Debug("Directing header downloads", "origin", from) defer p.log.Debug("Header download terminated") @@ -825,6 +801,18 @@ func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error { } // If no more headers are inbound, notify the content fetchers and return if packet.Items() == 0 { + // Don't abort header fetches while the pivot is downloading + if atomic.LoadInt32(&d.committed) == 0 && pivot <= from { + p.log.Debug("No headers, waiting for pivot commit") + select { + case <-time.After(fsHeaderContCheck): + getHeaders(from) + continue + case <-d.cancelCh: + return errCancelHeaderFetch + } + } + // Pivot done (or not in fast sync) and no more headers, terminate the process p.log.Debug("No more headers available") select { case d.headerProcCh <- nil: @@ -1129,10 +1117,8 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv } if request.From > 0 { peer.log.Trace("Requesting new batch of data", "type", kind, "from", request.From) - } else if len(request.Headers) > 0 { - peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Headers), "from", request.Headers[0].Number) } else { - peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Hashes)) + peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Headers), "from", request.Headers[0].Number) } // Fetch the chunk and make sure any errors return the hashes to the queue if fetchHook != nil { @@ -1160,10 +1146,7 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv // processHeaders takes batches of retrieved headers from an input channel and // keeps processing and scheduling them into the header chain and downloader's // queue until the stream ends or a failure occurs. 
-func (d *Downloader) processHeaders(origin uint64, td *big.Int) error { - // Calculate the pivoting point for switching from fast to slow sync - pivot := d.queue.FastSyncPivot() - +func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) error { // Keep a count of uncertain headers to roll back rollback := []*types.Header{} defer func() { @@ -1188,19 +1171,6 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error { "header", fmt.Sprintf("%d->%d", lastHeader, d.lightchain.CurrentHeader().Number), "fast", fmt.Sprintf("%d->%d", lastFastBlock, curFastBlock), "block", fmt.Sprintf("%d->%d", lastBlock, curBlock)) - - // If we're already past the pivot point, this could be an attack, thread carefully - if rollback[len(rollback)-1].Number.Uint64() > pivot { - // If we didn't ever fail, lock in the pivot header (must! not! change!) - if atomic.LoadUint32(&d.fsPivotFails) == 0 { - for _, header := range rollback { - if header.Number.Uint64() == pivot { - log.Warn("Fast-sync pivot locked in", "number", pivot, "hash", header.Hash()) - d.fsPivotLock = header - } - } - } - } } }() @@ -1302,13 +1272,6 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error { rollback = append(rollback[:0], rollback[len(rollback)-fsHeaderSafetyNet:]...) 
} } - // If we're fast syncing and just pulled in the pivot, make sure it's the one locked in - if d.mode == FastSync && d.fsPivotLock != nil && chunk[0].Number.Uint64() <= pivot && chunk[len(chunk)-1].Number.Uint64() >= pivot { - if pivot := chunk[int(pivot-chunk[0].Number.Uint64())]; pivot.Hash() != d.fsPivotLock.Hash() { - log.Warn("Pivot doesn't match locked in one", "remoteNumber", pivot.Number, "remoteHash", pivot.Hash(), "localNumber", d.fsPivotLock.Number, "localHash", d.fsPivotLock.Hash()) - return errInvalidChain - } - } // Unless we're doing light chains, schedule the headers for associated content retrieval if d.mode == FullSync || d.mode == FastSync { // If we've reached the allowed number of pending headers, stall a bit @@ -1343,7 +1306,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error { // processFullSyncContent takes fetch results from the queue and imports them into the chain. func (d *Downloader) processFullSyncContent() error { for { - results := d.queue.WaitResults() + results := d.queue.Results(true) if len(results) == 0 { return nil } @@ -1357,30 +1320,28 @@ func (d *Downloader) processFullSyncContent() error { } func (d *Downloader) importBlockResults(results []*fetchResult) error { - for len(results) != 0 { - // Check for any termination requests. This makes clean shutdown faster. 
- select { - case <-d.quitCh: - return errCancelContentProcessing - default: - } - // Retrieve the a batch of results to import - items := int(math.Min(float64(len(results)), float64(maxResultsProcess))) - first, last := results[0].Header, results[items-1].Header - log.Debug("Inserting downloaded chain", "items", len(results), - "firstnum", first.Number, "firsthash", first.Hash(), - "lastnum", last.Number, "lasthash", last.Hash(), - ) - blocks := make([]*types.Block, items) - for i, result := range results[:items] { - blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles) - } - if index, err := d.blockchain.InsertChain(blocks); err != nil { - log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err) - return errInvalidChain - } - // Shift the results to the next batch - results = results[items:] + // Check for any early termination requests + if len(results) == 0 { + return nil + } + select { + case <-d.quitCh: + return errCancelContentProcessing + default: + } + // Retrieve the a batch of results to import + first, last := results[0].Header, results[len(results)-1].Header + log.Debug("Inserting downloaded chain", "items", len(results), + "firstnum", first.Number, "firsthash", first.Hash(), + "lastnum", last.Number, "lasthash", last.Hash(), + ) + blocks := make([]*types.Block, len(results)) + for i, result := range results { + blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles) + } + if index, err := d.blockchain.InsertChain(blocks); err != nil { + log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err) + return errInvalidChain } return nil } @@ -1388,35 +1349,92 @@ func (d *Downloader) importBlockResults(results []*fetchResult) error { // processFastSyncContent takes fetch results from the queue and writes them to the // 
database. It also controls the synchronisation of state nodes of the pivot block. func (d *Downloader) processFastSyncContent(latest *types.Header) error { - // Start syncing state of the reported head block. - // This should get us most of the state of the pivot block. + // Start syncing state of the reported head block. This should get us most of + // the state of the pivot block. stateSync := d.syncState(latest.Root) defer stateSync.Cancel() go func() { - if err := stateSync.Wait(); err != nil { + if err := stateSync.Wait(); err != nil && err != errCancelStateFetch { d.queue.Close() // wake up WaitResults } }() - - pivot := d.queue.FastSyncPivot() + // Figure out the ideal pivot block. Note, that this goalpost may move if the + // sync takes long enough for the chain head to move significantly. + pivot := uint64(0) + if height := latest.Number.Uint64(); height > uint64(fsMinFullBlocks) { + pivot = height - uint64(fsMinFullBlocks) + } + // To cater for moving pivot points, track the pivot block and subsequently + // accumulated download results separately. + var ( + oldPivot *fetchResult // Locked in pivot block, might change eventually + oldTail []*fetchResult // Downloaded content after the pivot + ) for { - results := d.queue.WaitResults() + // Wait for the next batch of downloaded data to be available, and if the pivot + // block became stale, move the goalpost + results := d.queue.Results(oldPivot == nil) // Block if we're not monitoring pivot staleness if len(results) == 0 { - return stateSync.Cancel() + // If pivot sync is done, stop + if oldPivot == nil { + return stateSync.Cancel() + } + // If sync failed, stop + select { + case <-d.cancelCh: + return stateSync.Cancel() + default: + } } if d.chainInsertHook != nil { d.chainInsertHook(results) } + if oldPivot != nil { + results = append(append([]*fetchResult{oldPivot}, oldTail...), results...) 
+ } + // Split around the pivot block and process the two sides via fast/full sync + if atomic.LoadInt32(&d.committed) == 0 { + latest = results[len(results)-1].Header + if height := latest.Number.Uint64(); height > pivot+2*uint64(fsMinFullBlocks) { + log.Warn("Pivot became stale, moving", "old", pivot, "new", height-uint64(fsMinFullBlocks)) + pivot = height - uint64(fsMinFullBlocks) + } + } P, beforeP, afterP := splitAroundPivot(pivot, results) if err := d.commitFastSyncData(beforeP, stateSync); err != nil { return err } if P != nil { - stateSync.Cancel() - if err := d.commitPivotBlock(P); err != nil { - return err + // If new pivot block found, cancel old state retrieval and restart + if oldPivot != P { + stateSync.Cancel() + + stateSync = d.syncState(P.Header.Root) + defer stateSync.Cancel() + go func() { + if err := stateSync.Wait(); err != nil && err != errCancelStateFetch { + d.queue.Close() // wake up WaitResults + } + }() + oldPivot = P + } + // Wait for completion, occasionally checking for pivot staleness + select { + case <-stateSync.done: + if stateSync.err != nil { + return stateSync.err + } + if err := d.commitPivotBlock(P); err != nil { + return err + } + oldPivot = nil + + case <-time.After(time.Second): + oldTail = afterP + continue } } + // Fast sync done, pivot commit done, full import if err := d.importBlockResults(afterP); err != nil { return err } @@ -1439,52 +1457,49 @@ func splitAroundPivot(pivot uint64, results []*fetchResult) (p *fetchResult, bef } func (d *Downloader) commitFastSyncData(results []*fetchResult, stateSync *stateSync) error { - for len(results) != 0 { - // Check for any termination requests. 
- select { - case <-d.quitCh: - return errCancelContentProcessing - case <-stateSync.done: - if err := stateSync.Wait(); err != nil { - return err - } - default: + // Check for any early termination requests + if len(results) == 0 { + return nil + } + select { + case <-d.quitCh: + return errCancelContentProcessing + case <-stateSync.done: + if err := stateSync.Wait(); err != nil { + return err } - // Retrieve the a batch of results to import - items := int(math.Min(float64(len(results)), float64(maxResultsProcess))) - first, last := results[0].Header, results[items-1].Header - log.Debug("Inserting fast-sync blocks", "items", len(results), - "firstnum", first.Number, "firsthash", first.Hash(), - "lastnumn", last.Number, "lasthash", last.Hash(), - ) - blocks := make([]*types.Block, items) - receipts := make([]types.Receipts, items) - for i, result := range results[:items] { - blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles) - receipts[i] = result.Receipts - } - if index, err := d.blockchain.InsertReceiptChain(blocks, receipts); err != nil { - log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err) - return errInvalidChain - } - // Shift the results to the next batch - results = results[items:] + default: + } + // Retrieve the a batch of results to import + first, last := results[0].Header, results[len(results)-1].Header + log.Debug("Inserting fast-sync blocks", "items", len(results), + "firstnum", first.Number, "firsthash", first.Hash(), + "lastnumn", last.Number, "lasthash", last.Hash(), + ) + blocks := make([]*types.Block, len(results)) + receipts := make([]types.Receipts, len(results)) + for i, result := range results { + blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles) + receipts[i] = result.Receipts + } + if index, err := d.blockchain.InsertReceiptChain(blocks, receipts); err != nil { + 
log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err) + return errInvalidChain } return nil } func (d *Downloader) commitPivotBlock(result *fetchResult) error { - b := types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles) - // Sync the pivot block state. This should complete reasonably quickly because - // we've already synced up to the reported head block state earlier. - if err := d.syncState(b.Root()).Wait(); err != nil { + block := types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles) + log.Debug("Committing fast sync pivot as new head", "number", block.Number(), "hash", block.Hash()) + if _, err := d.blockchain.InsertReceiptChain([]*types.Block{block}, []types.Receipts{result.Receipts}); err != nil { return err } - log.Debug("Committing fast sync pivot as new head", "number", b.Number(), "hash", b.Hash()) - if _, err := d.blockchain.InsertReceiptChain([]*types.Block{b}, []types.Receipts{result.Receipts}); err != nil { + if err := d.blockchain.FastSyncCommitHead(block.Hash()); err != nil { return err } - return d.blockchain.FastSyncCommitHead(b.Hash()) + atomic.StoreInt32(&d.committed, 1) + return nil } // DeliverHeaders injects a new batch of block headers received from a remote diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index e9c7b6170..d94d55f11 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -28,7 +28,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" @@ -45,8 +44,8 @@ var ( // Reduce some of the parameters to make the tester faster. 
func init() { MaxForkAncestry = uint64(10000) - blockCacheLimit = 1024 - fsCriticalTrials = 10 + blockCacheItems = 1024 + fsHeaderContCheck = 500 * time.Millisecond } // downloadTester is a test simulator for mocking out local block chain. @@ -223,7 +222,7 @@ func (dl *downloadTester) HasHeader(hash common.Hash, number uint64) bool { } // HasBlockAndState checks if a block and associated state is present in the testers canonical chain. -func (dl *downloadTester) HasBlockAndState(hash common.Hash) bool { +func (dl *downloadTester) HasBlockAndState(hash common.Hash, number uint64) bool { block := dl.GetBlockByHash(hash) if block == nil { return false @@ -293,7 +292,7 @@ func (dl *downloadTester) CurrentFastBlock() *types.Block { func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error { // For now only check that the state trie is correct if block := dl.GetBlockByHash(hash); block != nil { - _, err := trie.NewSecure(block.Root(), dl.stateDb, 0) + _, err := trie.NewSecure(block.Root(), trie.NewDatabase(dl.stateDb), 0) return err } return fmt.Errorf("non existent block: %x", hash[:4]) @@ -619,28 +618,22 @@ func assertOwnChain(t *testing.T, tester *downloadTester, length int) { // number of items of the various chain components. 
func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) { // Initialize the counters for the first fork - headers, blocks := lengths[0], lengths[0] + headers, blocks, receipts := lengths[0], lengths[0], lengths[0]-fsMinFullBlocks - minReceipts, maxReceipts := lengths[0]-fsMinFullBlocks-fsPivotInterval, lengths[0]-fsMinFullBlocks - if minReceipts < 0 { - minReceipts = 1 - } - if maxReceipts < 0 { - maxReceipts = 1 + if receipts < 0 { + receipts = 1 } // Update the counters for each subsequent fork for _, length := range lengths[1:] { headers += length - common blocks += length - common - - minReceipts += length - common - fsMinFullBlocks - fsPivotInterval - maxReceipts += length - common - fsMinFullBlocks + receipts += length - common - fsMinFullBlocks } switch tester.downloader.mode { case FullSync: - minReceipts, maxReceipts = 1, 1 + receipts = 1 case LightSync: - blocks, minReceipts, maxReceipts = 1, 1, 1 + blocks, receipts = 1, 1 } if hs := len(tester.ownHeaders); hs != headers { t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, headers) @@ -648,11 +641,12 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng if bs := len(tester.ownBlocks); bs != blocks { t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, blocks) } - if rs := len(tester.ownReceipts); rs < minReceipts || rs > maxReceipts { - t.Fatalf("synchronised receipts mismatch: have %v, want between [%v, %v]", rs, minReceipts, maxReceipts) + if rs := len(tester.ownReceipts); rs != receipts { + t.Fatalf("synchronised receipts mismatch: have %v, want %v", rs, receipts) } // Verify the state trie too for fast syncs - if tester.downloader.mode == FastSync { + /*if tester.downloader.mode == FastSync { + pivot := uint64(0) var index int if pivot := int(tester.downloader.queue.fastSyncPivot); pivot < common { index = pivot @@ -660,11 +654,11 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng 
index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot) } if index > 0 { - if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(tester.stateDb)); statedb == nil || err != nil { + if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(trie.NewDatabase(tester.stateDb))); statedb == nil || err != nil { t.Fatalf("state reconstruction failed: %v", err) } } - } + }*/ } // Tests that simple synchronization against a canonical chain works correctly. @@ -684,7 +678,7 @@ func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) { defer tester.terminate() // Create a small enough block chain to download - targetBlocks := blockCacheLimit - 15 + targetBlocks := blockCacheItems - 15 hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) @@ -710,7 +704,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) { defer tester.terminate() // Create a long block chain to download and the tester - targetBlocks := 8 * blockCacheLimit + targetBlocks := 8 * blockCacheItems hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) @@ -745,9 +739,9 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) { cached = len(tester.downloader.queue.blockDonePool) if mode == FastSync { if receipts := len(tester.downloader.queue.receiptDonePool); receipts < cached { - if tester.downloader.queue.resultCache[receipts].Header.Number.Uint64() < tester.downloader.queue.fastSyncPivot { - cached = receipts - } + //if tester.downloader.queue.resultCache[receipts].Header.Number.Uint64() < tester.downloader.queue.fastSyncPivot { + cached = receipts + //} } } frozen = int(atomic.LoadUint32(&blocked)) @@ -755,7 +749,7 @@ func 
testThrottling(t *testing.T, protocol int, mode SyncMode) { tester.downloader.queue.lock.Unlock() tester.lock.Unlock() - if cached == blockCacheLimit || retrieved+cached+frozen == targetBlocks+1 { + if cached == blockCacheItems || retrieved+cached+frozen == targetBlocks+1 { break } } @@ -765,8 +759,8 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) { tester.lock.RLock() retrieved = len(tester.ownBlocks) tester.lock.RUnlock() - if cached != blockCacheLimit && retrieved+cached+frozen != targetBlocks+1 { - t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheLimit, retrieved, frozen, targetBlocks+1) + if cached != blockCacheItems && retrieved+cached+frozen != targetBlocks+1 { + t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheItems, retrieved, frozen, targetBlocks+1) } // Permit the blocked blocks to import if atomic.LoadUint32(&blocked) > 0 { @@ -974,7 +968,7 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) { defer tester.terminate() // Create a small enough block chain to download and the tester - targetBlocks := blockCacheLimit - 15 + targetBlocks := blockCacheItems - 15 if targetBlocks >= MaxHashFetch { targetBlocks = MaxHashFetch - 15 } @@ -1016,12 +1010,12 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) { // Create various peers with various parts of the chain targetPeers := 8 - targetBlocks := targetPeers*blockCacheLimit - 15 + targetBlocks := targetPeers*blockCacheItems - 15 hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) for i := 0; i < targetPeers; i++ { id := fmt.Sprintf("peer #%d", i) - tester.newPeer(id, protocol, hashes[i*blockCacheLimit:], headers, blocks, receipts) + tester.newPeer(id, protocol, hashes[i*blockCacheItems:], headers, blocks, receipts) } if err := tester.sync("peer #0", nil, mode); err != nil { t.Fatalf("failed to synchronise 
blocks: %v", err) @@ -1045,7 +1039,7 @@ func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) { defer tester.terminate() // Create a small enough block chain to download - targetBlocks := blockCacheLimit - 15 + targetBlocks := blockCacheItems - 15 hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) // Create peers of every type @@ -1084,7 +1078,7 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) { defer tester.terminate() // Create a block chain to download - targetBlocks := 2*blockCacheLimit - 15 + targetBlocks := 2*blockCacheItems - 15 hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) @@ -1110,8 +1104,8 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) { bodiesNeeded++ } } - for hash, receipt := range receipts { - if mode == FastSync && len(receipt) > 0 && headers[hash].Number.Uint64() <= tester.downloader.queue.fastSyncPivot { + for _, receipt := range receipts { + if mode == FastSync && len(receipt) > 0 { receiptsNeeded++ } } @@ -1139,7 +1133,7 @@ func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) { defer tester.terminate() // Create a small enough block chain to download - targetBlocks := blockCacheLimit - 15 + targetBlocks := blockCacheItems - 15 hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) // Attempt a full sync with an attacker feeding gapped headers @@ -1174,7 +1168,7 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) { defer tester.terminate() // Create a small enough block chain to download - targetBlocks := blockCacheLimit - 15 + targetBlocks := blockCacheItems - 15 hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) // Attempt a full sync with an attacker feeding shifted headers @@ -1208,7 
+1202,7 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { defer tester.terminate() // Create a small enough block chain to download - targetBlocks := 3*fsHeaderSafetyNet + fsPivotInterval + fsMinFullBlocks + targetBlocks := 3*fsHeaderSafetyNet + 256 + fsMinFullBlocks hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) // Attempt to sync with an attacker that feeds junk during the fast sync phase. @@ -1248,7 +1242,6 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { tester.newPeer("withhold-attack", protocol, hashes, headers, blocks, receipts) missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1 - tester.downloader.fsPivotFails = 0 tester.downloader.syncInitHook = func(uint64, uint64) { for i := missing; i <= len(hashes); i++ { delete(tester.peerHeaders["withhold-attack"], hashes[len(hashes)-i]) @@ -1267,8 +1260,6 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { t.Errorf("fast sync pivot block #%d not rolled back", head) } } - tester.downloader.fsPivotFails = fsCriticalTrials - // Synchronise with the valid peer and make sure sync succeeds. Since the last // rollback should also disable fast syncing for this process, verify that we // did a fresh full sync. 
Note, we can't assert anything about the receipts @@ -1383,7 +1374,7 @@ func testSyncProgress(t *testing.T, protocol int, mode SyncMode) { defer tester.terminate() // Create a small enough block chain to download - targetBlocks := blockCacheLimit - 15 + targetBlocks := blockCacheItems - 15 hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) // Set a sync init hook to catch progress changes @@ -1532,7 +1523,7 @@ func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) { defer tester.terminate() // Create a small enough block chain to download - targetBlocks := blockCacheLimit - 15 + targetBlocks := blockCacheItems - 15 hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) // Set a sync init hook to catch progress changes @@ -1609,7 +1600,7 @@ func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) { defer tester.terminate() // Create a small block chain - targetBlocks := blockCacheLimit - 15 + targetBlocks := blockCacheItems - 15 hashes, headers, blocks, receipts := tester.makeChain(targetBlocks+3, 0, tester.genesis, nil, false) // Set a sync init hook to catch progress changes @@ -1697,6 +1688,7 @@ func TestDeliverHeadersHang(t *testing.T) { type floodingTestPeer struct { peer Peer tester *downloadTester + pend sync.WaitGroup } func (ftp *floodingTestPeer) Head() (common.Hash, *big.Int) { return ftp.peer.Head() } @@ -1717,9 +1709,12 @@ func (ftp *floodingTestPeer) RequestHeadersByNumber(from uint64, count, skip int deliveriesDone := make(chan struct{}, 500) for i := 0; i < cap(deliveriesDone); i++ { peer := fmt.Sprintf("fake-peer%d", i) + ftp.pend.Add(1) + go func() { ftp.tester.downloader.DeliverHeaders(peer, []*types.Header{{}, {}, {}, {}}) deliveriesDone <- struct{}{} + ftp.pend.Done() }() } // Deliver the actual requested headers. 
@@ -1751,110 +1746,15 @@ func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) { // Whenever the downloader requests headers, flood it with // a lot of unrequested header deliveries. tester.downloader.peers.peers["peer"].peer = &floodingTestPeer{ - tester.downloader.peers.peers["peer"].peer, - tester, + peer: tester.downloader.peers.peers["peer"].peer, + tester: tester, } if err := tester.sync("peer", nil, mode); err != nil { - t.Errorf("sync failed: %v", err) + t.Errorf("test %d: sync failed: %v", i, err) } tester.terminate() + + // Flush all goroutines to prevent messing with subsequent tests + tester.downloader.peers.peers["peer"].peer.(*floodingTestPeer).pend.Wait() } } - -// Tests that if fast sync aborts in the critical section, it can restart a few -// times before giving up. -// We use data driven subtests to manage this so that it will be parallel on its own -// and not with the other tests, avoiding intermittent failures. -func TestFastCriticalRestarts(t *testing.T) { - testCases := []struct { - protocol int - progress bool - }{ - {63, false}, - {64, false}, - {63, true}, - {64, true}, - } - for _, tc := range testCases { - t.Run(fmt.Sprintf("protocol %d progress %v", tc.protocol, tc.progress), func(t *testing.T) { - testFastCriticalRestarts(t, tc.protocol, tc.progress) - }) - } -} - -func testFastCriticalRestarts(t *testing.T, protocol int, progress bool) { - t.Parallel() - - tester := newTester() - defer tester.terminate() - - // Create a large enough blockchin to actually fast sync on - targetBlocks := fsMinFullBlocks + 2*fsPivotInterval - 15 - hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false) - - // Create a tester peer with a critical section header missing (force failures) - tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) - delete(tester.peerHeaders["peer"], hashes[fsMinFullBlocks-1]) - tester.downloader.dropPeer = func(id string) {} // We reuse the same "faulty" 
peer throughout the test - - // Remove all possible pivot state roots and slow down replies (test failure resets later) - for i := 0; i < fsPivotInterval; i++ { - tester.peerMissingStates["peer"][headers[hashes[fsMinFullBlocks+i]].Root] = true - } - (tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).setDelay(500 * time.Millisecond) // Enough to reach the critical section - - // Synchronise with the peer a few times and make sure they fail until the retry limit - for i := 0; i < int(fsCriticalTrials)-1; i++ { - // Attempt a sync and ensure it fails properly - if err := tester.sync("peer", nil, FastSync); err == nil { - t.Fatalf("failing fast sync succeeded: %v", err) - } - time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain - - // If it's the first failure, pivot should be locked => reenable all others to detect pivot changes - if i == 0 { - time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain - if tester.downloader.fsPivotLock == nil { - time.Sleep(400 * time.Millisecond) // Make sure the first huge timeout expires too - t.Fatalf("pivot block not locked in after critical section failure") - } - tester.lock.Lock() - tester.peerHeaders["peer"][hashes[fsMinFullBlocks-1]] = headers[hashes[fsMinFullBlocks-1]] - tester.peerMissingStates["peer"] = map[common.Hash]bool{tester.downloader.fsPivotLock.Root: true} - (tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).setDelay(0) - tester.lock.Unlock() - } - } - // Return all nodes if we're testing fast sync progression - if progress { - tester.lock.Lock() - tester.peerMissingStates["peer"] = map[common.Hash]bool{} - tester.lock.Unlock() - - if err := tester.sync("peer", nil, FastSync); err != nil { - t.Fatalf("failed to synchronise blocks in progressed fast sync: %v", err) - } - time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain - - if fails := atomic.LoadUint32(&tester.downloader.fsPivotFails); fails != 1 { - 
t.Fatalf("progressed pivot trial count mismatch: have %v, want %v", fails, 1) - } - assertOwnChain(t, tester, targetBlocks+1) - } else { - if err := tester.sync("peer", nil, FastSync); err == nil { - t.Fatalf("succeeded to synchronise blocks in failed fast sync") - } - time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain - - if fails := atomic.LoadUint32(&tester.downloader.fsPivotFails); fails != fsCriticalTrials { - t.Fatalf("failed pivot trial count mismatch: have %v, want %v", fails, fsCriticalTrials) - } - } - // Retry limit exhausted, downloader will switch to full sync, should succeed - if err := tester.sync("peer", nil, FastSync); err != nil { - t.Fatalf("failed to synchronise blocks in slow sync: %v", err) - } - // Note, we can't assert the chain here because the test asserter assumes sync - // completed using a single mode of operation, whereas fast-then-slow can result - // in arbitrary intermediate state that's not cleanly verifiable. -} diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index 6926f1d8c..a1a70e46e 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -32,7 +32,11 @@ import ( "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) -var blockCacheLimit = 8192 // Maximum number of blocks to cache before throttling the download +var ( + blockCacheItems = 8192 // Maximum number of blocks to cache before throttling the download + blockCacheMemory = 64 * 1024 * 1024 // Maximum amount of memory to use for block caching + blockCacheSizeWeight = 0.1 // Multiplier to approximate the average block size based on past ones +) var ( errNoFetchesPending = errors.New("no fetches pending") @@ -41,17 +45,17 @@ var ( // fetchRequest is a currently running data retrieval operation. 
type fetchRequest struct { - Peer *peerConnection // Peer to which the request was sent - From uint64 // [eth/62] Requested chain element index (used for skeleton fills only) - Hashes map[common.Hash]int // [eth/61] Requested hashes with their insertion index (priority) - Headers []*types.Header // [eth/62] Requested headers, sorted by request order - Time time.Time // Time when the request was made + Peer *peerConnection // Peer to which the request was sent + From uint64 // [eth/62] Requested chain element index (used for skeleton fills only) + Headers []*types.Header // [eth/62] Requested headers, sorted by request order + Time time.Time // Time when the request was made } // fetchResult is a struct collecting partial results from data fetchers until // all outstanding pieces complete and the result as a whole can be processed. type fetchResult struct { - Pending int // Number of data fetches still pending + Pending int // Number of data fetches still pending + Hash common.Hash // Hash of the header to prevent recalculating Header *types.Header Uncles []*types.Header @@ -61,12 +65,10 @@ type fetchResult struct { // queue represents hashes that are either need fetching or are being fetched type queue struct { - mode SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching - fastSyncPivot uint64 // Block number where the fast sync pivots into archive synchronisation mode - - headerHead common.Hash // [eth/62] Hash of the last queued header to verify order + mode SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching // Headers are "special", they download in batches, supported by a skeleton chain + headerHead common.Hash // [eth/62] Hash of the last queued header to verify order headerTaskPool map[uint64]*types.Header // [eth/62] Pending header retrieval tasks, mapping starting indexes to skeleton headers headerTaskQueue *prque.Prque // [eth/62] Priority queue of the skeleton indexes to fetch the 
filling headers for headerPeerMiss map[string]map[uint64]struct{} // [eth/62] Set of per-peer header batches known to be unavailable @@ -87,8 +89,9 @@ type queue struct { receiptPendPool map[string]*fetchRequest // [eth/63] Currently pending receipt retrieval operations receiptDonePool map[common.Hash]struct{} // [eth/63] Set of the completed receipt fetches - resultCache []*fetchResult // Downloaded but not yet delivered fetch results - resultOffset uint64 // Offset of the first cached fetch result in the block chain + resultCache []*fetchResult // Downloaded but not yet delivered fetch results + resultOffset uint64 // Offset of the first cached fetch result in the block chain + resultSize common.StorageSize // Approximate size of a block (exponential moving average) lock *sync.Mutex active *sync.Cond @@ -109,7 +112,7 @@ func newQueue() *queue { receiptTaskQueue: prque.New(), receiptPendPool: make(map[string]*fetchRequest), receiptDonePool: make(map[common.Hash]struct{}), - resultCache: make([]*fetchResult, blockCacheLimit), + resultCache: make([]*fetchResult, blockCacheItems), active: sync.NewCond(lock), lock: lock, } @@ -122,10 +125,8 @@ func (q *queue) Reset() { q.closed = false q.mode = FullSync - q.fastSyncPivot = 0 q.headerHead = common.Hash{} - q.headerPendPool = make(map[string]*fetchRequest) q.blockTaskPool = make(map[common.Hash]*types.Header) @@ -138,7 +139,7 @@ func (q *queue) Reset() { q.receiptPendPool = make(map[string]*fetchRequest) q.receiptDonePool = make(map[common.Hash]struct{}) - q.resultCache = make([]*fetchResult, blockCacheLimit) + q.resultCache = make([]*fetchResult, blockCacheItems) q.resultOffset = 0 } @@ -214,27 +215,13 @@ func (q *queue) Idle() bool { return (queued + pending + cached) == 0 } -// FastSyncPivot retrieves the currently used fast sync pivot point. 
-func (q *queue) FastSyncPivot() uint64 { - q.lock.Lock() - defer q.lock.Unlock() - - return q.fastSyncPivot -} - // ShouldThrottleBlocks checks if the download should be throttled (active block (body) // fetches exceed block cache). func (q *queue) ShouldThrottleBlocks() bool { q.lock.Lock() defer q.lock.Unlock() - // Calculate the currently in-flight block (body) requests - pending := 0 - for _, request := range q.blockPendPool { - pending += len(request.Hashes) + len(request.Headers) - } - // Throttle if more blocks (bodies) are in-flight than free space in the cache - return pending >= len(q.resultCache)-len(q.blockDonePool) + return q.resultSlots(q.blockPendPool, q.blockDonePool) <= 0 } // ShouldThrottleReceipts checks if the download should be throttled (active receipt @@ -243,13 +230,39 @@ func (q *queue) ShouldThrottleReceipts() bool { q.lock.Lock() defer q.lock.Unlock() - // Calculate the currently in-flight receipt requests - pending := 0 - for _, request := range q.receiptPendPool { - pending += len(request.Headers) + return q.resultSlots(q.receiptPendPool, q.receiptDonePool) <= 0 +} + +// resultSlots calculates the number of results slots available for requests +// whilst adhering to both the item and the memory limit too of the results +// cache. 
+func (q *queue) resultSlots(pendPool map[string]*fetchRequest, donePool map[common.Hash]struct{}) int { + // Calculate the maximum length capped by the memory limit + limit := len(q.resultCache) + if common.StorageSize(len(q.resultCache))*q.resultSize > common.StorageSize(blockCacheMemory) { + limit = int((common.StorageSize(blockCacheMemory) + q.resultSize - 1) / q.resultSize) } - // Throttle if more receipts are in-flight than free space in the cache - return pending >= len(q.resultCache)-len(q.receiptDonePool) + // Calculate the number of slots already finished + finished := 0 + for _, result := range q.resultCache[:limit] { + if result == nil { + break + } + if _, ok := donePool[result.Hash]; ok { + finished++ + } + } + // Calculate the number of slots currently downloading + pending := 0 + for _, request := range pendPool { + for _, header := range request.Headers { + if header.Number.Uint64() < q.resultOffset+uint64(limit) { + pending++ + } + } + } + // Return the free slots to distribute + return limit - finished - pending } // ScheduleSkeleton adds a batch of header retrieval tasks to the queue to fill @@ -323,8 +336,7 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header { q.blockTaskPool[hash] = header q.blockTaskQueue.Push(header, -float32(header.Number.Uint64())) - if q.mode == FastSync && header.Number.Uint64() <= q.fastSyncPivot { - // Fast phase of the fast sync, retrieve receipts too + if q.mode == FastSync { q.receiptTaskPool[hash] = header q.receiptTaskQueue.Push(header, -float32(header.Number.Uint64())) } @@ -335,18 +347,25 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header { return inserts } -// WaitResults retrieves and permanently removes a batch of fetch -// results from the cache. the result slice will be empty if the queue -// has been closed. -func (q *queue) WaitResults() []*fetchResult { +// Results retrieves and permanently removes a batch of fetch results from +// the cache. 
the result slice will be empty if the queue has been closed. +func (q *queue) Results(block bool) []*fetchResult { q.lock.Lock() defer q.lock.Unlock() + // Count the number of items available for processing nproc := q.countProcessableItems() for nproc == 0 && !q.closed { + if !block { + return nil + } q.active.Wait() nproc = q.countProcessableItems() } + // Since we have a batch limit, don't pull more into "dangling" memory + if nproc > maxResultsProcess { + nproc = maxResultsProcess + } results := make([]*fetchResult, nproc) copy(results, q.resultCache[:nproc]) if len(results) > 0 { @@ -363,6 +382,21 @@ func (q *queue) WaitResults() []*fetchResult { } // Advance the expected block number of the first cache entry. q.resultOffset += uint64(nproc) + + // Recalculate the result item weights to prevent memory exhaustion + for _, result := range results { + size := result.Header.Size() + for _, uncle := range result.Uncles { + size += uncle.Size() + } + for _, receipt := range result.Receipts { + size += receipt.Size() + } + for _, tx := range result.Transactions { + size += tx.Size() + } + q.resultSize = common.StorageSize(blockCacheSizeWeight)*size + (1-common.StorageSize(blockCacheSizeWeight))*q.resultSize + } } return results } @@ -370,21 +404,9 @@ func (q *queue) WaitResults() []*fetchResult { // countProcessableItems counts the processable items. func (q *queue) countProcessableItems() int { for i, result := range q.resultCache { - // Don't process incomplete or unavailable items. if result == nil || result.Pending > 0 { return i } - // Stop before processing the pivot block to ensure that - // resultCache has space for fsHeaderForceVerify items. Not - // doing this could leave us unable to download the required - // amount of headers. 
- if q.mode == FastSync && result.Header.Number.Uint64() == q.fastSyncPivot { - for j := 0; j < fsHeaderForceVerify; j++ { - if i+j+1 >= len(q.resultCache) || q.resultCache[i+j+1] == nil { - return i - } - } - } } return len(q.resultCache) } @@ -473,10 +495,8 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common return nil, false, nil } // Calculate an upper limit on the items we might fetch (i.e. throttling) - space := len(q.resultCache) - len(donePool) - for _, request := range pendPool { - space -= len(request.Headers) - } + space := q.resultSlots(pendPool, donePool) + // Retrieve a batch of tasks, skipping previously failed ones send := make([]*types.Header, 0, count) skip := make([]*types.Header, 0) @@ -484,6 +504,7 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common progress := false for proc := 0; proc < space && len(send) < count && !taskQueue.Empty(); proc++ { header := taskQueue.PopItem().(*types.Header) + hash := header.Hash() // If we're the first to request this task, initialise the result container index := int(header.Number.Int64() - int64(q.resultOffset)) @@ -493,18 +514,19 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common } if q.resultCache[index] == nil { components := 1 - if q.mode == FastSync && header.Number.Uint64() <= q.fastSyncPivot { + if q.mode == FastSync { components = 2 } q.resultCache[index] = &fetchResult{ Pending: components, + Hash: hash, Header: header, } } // If this fetch task is a noop, skip this fetch operation if isNoop(header) { - donePool[header.Hash()] = struct{}{} - delete(taskPool, header.Hash()) + donePool[hash] = struct{}{} + delete(taskPool, hash) space, proc = space-1, proc-1 q.resultCache[index].Pending-- @@ -512,7 +534,7 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common continue } // Otherwise unless the peer is known not to have the data, add to the retrieve list - if 
p.Lacks(header.Hash()) { + if p.Lacks(hash) { skip = append(skip, header) } else { send = append(send, header) @@ -565,9 +587,6 @@ func (q *queue) cancel(request *fetchRequest, taskQueue *prque.Prque, pendPool m if request.From > 0 { taskQueue.Push(request.From, -float32(request.From)) } - for hash, index := range request.Hashes { - taskQueue.Push(hash, float32(index)) - } for _, header := range request.Headers { taskQueue.Push(header, -float32(header.Number.Uint64())) } @@ -640,18 +659,11 @@ func (q *queue) expire(timeout time.Duration, pendPool map[string]*fetchRequest, if request.From > 0 { taskQueue.Push(request.From, -float32(request.From)) } - for hash, index := range request.Hashes { - taskQueue.Push(hash, float32(index)) - } for _, header := range request.Headers { taskQueue.Push(header, -float32(header.Number.Uint64())) } // Add the peer to the expiry report along the the number of failed requests - expirations := len(request.Hashes) - if expirations < len(request.Headers) { - expirations = len(request.Headers) - } - expiries[id] = expirations + expiries[id] = len(request.Headers) } } // Remove the expired requests from the pending pool @@ -828,14 +840,16 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ failure = err break } - donePool[header.Hash()] = struct{}{} + hash := header.Hash() + + donePool[hash] = struct{}{} q.resultCache[index].Pending-- useful = true accepted++ // Clean up a successful fetch request.Headers[i] = nil - delete(taskPool, header.Hash()) + delete(taskPool, hash) } // Return all failed or missing fetches to the queue for _, header := range request.Headers { @@ -860,7 +874,7 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ // Prepare configures the result cache to allow accepting and caching inbound // fetch results. 
-func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64, head *types.Header) { +func (q *queue) Prepare(offset uint64, mode SyncMode) { q.lock.Lock() defer q.lock.Unlock() @@ -868,6 +882,5 @@ func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64, head *types. if q.resultOffset < offset { q.resultOffset = offset } - q.fastSyncPivot = pivot q.mode = mode } diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go index 937828b94..9cc65a208 100644 --- a/eth/downloader/statesync.go +++ b/eth/downloader/statesync.go @@ -20,7 +20,6 @@ import ( "fmt" "hash" "sync" - "sync/atomic" "time" "github.com/ethereum/go-ethereum/common" @@ -294,6 +293,9 @@ func (s *stateSync) loop() error { case <-s.cancel: return errCancelStateFetch + case <-s.d.cancelCh: + return errCancelStateFetch + case req := <-s.deliver: // Response, disconnect or timeout triggered, drop the peer if stalling log.Trace("Received node data response", "peer", req.peer.id, "count", len(req.response), "dropped", req.dropped, "timeout", !req.dropped && req.timedOut()) @@ -304,15 +306,11 @@ func (s *stateSync) loop() error { s.d.dropPeer(req.peer.id) } // Process all the received blobs and check for stale delivery - stale, err := s.process(req) - if err != nil { + if err := s.process(req); err != nil { log.Warn("Node data write error", "err", err) return err } - // The the delivery contains requested data, mark the node idle (otherwise it's a timed out delivery) - if !stale { - req.peer.SetNodeDataIdle(len(req.response)) - } + req.peer.SetNodeDataIdle(len(req.response)) } } return s.commit(true) @@ -352,6 +350,7 @@ func (s *stateSync) assignTasks() { case s.d.trackStateReq <- req: req.peer.FetchNodeData(req.items) case <-s.cancel: + case <-s.d.cancelCh: } } } @@ -390,7 +389,7 @@ func (s *stateSync) fillTasks(n int, req *stateReq) { // process iterates over a batch of delivered state data, injecting each item // into a running state sync, re-queuing any items that were requested 
but not // delivered. -func (s *stateSync) process(req *stateReq) (bool, error) { +func (s *stateSync) process(req *stateReq) error { // Collect processing stats and update progress if valid data was received duplicate, unexpected := 0, 0 @@ -401,7 +400,7 @@ func (s *stateSync) process(req *stateReq) (bool, error) { }(time.Now()) // Iterate over all the delivered data and inject one-by-one into the trie - progress, stale := false, len(req.response) > 0 + progress := false for _, blob := range req.response { prog, hash, err := s.processNodeData(blob) @@ -415,20 +414,12 @@ func (s *stateSync) process(req *stateReq) (bool, error) { case trie.ErrAlreadyProcessed: duplicate++ default: - return stale, fmt.Errorf("invalid state node %s: %v", hash.TerminalString(), err) + return fmt.Errorf("invalid state node %s: %v", hash.TerminalString(), err) } - // If the node delivered a requested item, mark the delivery non-stale if _, ok := req.tasks[hash]; ok { delete(req.tasks, hash) - stale = false } } - // If we're inside the critical section, reset fail counter since we progressed. - if progress && atomic.LoadUint32(&s.d.fsPivotFails) > 1 { - log.Trace("Fast-sync progressed, resetting fail counter", "previous", atomic.LoadUint32(&s.d.fsPivotFails)) - atomic.StoreUint32(&s.d.fsPivotFails, 1) // Don't ever reset to 0, as that will unlock the pivot block - } - // Put unfulfilled tasks back into the retry queue npeers := s.d.peers.Len() for hash, task := range req.tasks { @@ -441,12 +432,12 @@ func (s *stateSync) process(req *stateReq) (bool, error) { // If we've requested the node too many times already, it may be a malicious // sync where nobody has the right data. Abort. 
if len(task.attempts) >= npeers { - return stale, fmt.Errorf("state node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers) + return fmt.Errorf("state node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers) } // Missing item, place into the retry queue. s.tasks[hash] = task } - return stale, nil + return nil } // processNodeData tries to inject a trie node data blob delivered from a remote diff --git a/eth/handler.go b/eth/handler.go index fcd53c5a6..c2426544f 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -71,7 +71,6 @@ type ProtocolManager struct { txpool txPool blockchain *core.BlockChain - chaindb ethdb.Database chainconfig *params.ChainConfig maxPeers int @@ -106,7 +105,6 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne eventMux: mux, txpool: txpool, blockchain: blockchain, - chaindb: chaindb, chainconfig: config, peers: newPeerSet(), newPeerCh: make(chan *peer), @@ -538,7 +536,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrDecode, "msg %v: %v", msg, err) } // Retrieve the requested state entry, stopping if enough was found - if entry, err := pm.chaindb.Get(hash.Bytes()); err == nil { + if entry, err := pm.blockchain.TrieNode(hash); err == nil { data = append(data, entry) bytes += len(entry) } @@ -576,7 +574,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrDecode, "msg %v: %v", msg, err) } // Retrieve the requested block's receipts, skipping if unknown to us - results := core.GetBlockReceipts(pm.chaindb, hash, core.GetBlockNumber(pm.chaindb, hash)) + results := pm.blockchain.GetReceiptsByHash(hash) if results == nil { if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { continue diff --git a/eth/handler_test.go b/eth/handler_test.go index 9a02eddfb..e336dfa28 100644 --- a/eth/handler_test.go +++ 
b/eth/handler_test.go @@ -56,7 +56,7 @@ func TestProtocolCompatibility(t *testing.T) { for i, tt := range tests { ProtocolVersions = []uint{tt.version} - pm, err := newTestProtocolManager(tt.mode, 0, nil, nil) + pm, _, err := newTestProtocolManager(tt.mode, 0, nil, nil) if pm != nil { defer pm.Stop() } @@ -71,7 +71,7 @@ func TestGetBlockHeaders62(t *testing.T) { testGetBlockHeaders(t, 62) } func TestGetBlockHeaders63(t *testing.T) { testGetBlockHeaders(t, 63) } func testGetBlockHeaders(t *testing.T, protocol int) { - pm := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil) + pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil) peer, _ := newTestPeer("peer", protocol, pm, true) defer peer.close() @@ -230,7 +230,7 @@ func TestGetBlockBodies62(t *testing.T) { testGetBlockBodies(t, 62) } func TestGetBlockBodies63(t *testing.T) { testGetBlockBodies(t, 63) } func testGetBlockBodies(t *testing.T, protocol int) { - pm := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil) + pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil) peer, _ := newTestPeer("peer", protocol, pm, true) defer peer.close() @@ -337,13 +337,13 @@ func testGetNodeData(t *testing.T, protocol int) { } } // Assemble the test environment - pm := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil) + pm, db := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil) peer, _ := newTestPeer("peer", protocol, pm, true) defer peer.close() // Fetch for now the entire chain db hashes := []common.Hash{} - for _, key := range pm.chaindb.(*ethdb.MemDatabase).Keys() { + for _, key := range db.Keys() { if len(key) == len(common.Hash{}) { hashes = append(hashes, common.BytesToHash(key)) } @@ -429,7 +429,7 @@ func testGetReceipt(t *testing.T, protocol int) { } } // Assemble the test environment - pm := 
newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil) + pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil) peer, _ := newTestPeer("peer", protocol, pm, true) defer peer.close() @@ -439,7 +439,7 @@ func testGetReceipt(t *testing.T, protocol int) { block := pm.blockchain.GetBlockByNumber(i) hashes = append(hashes, block.Hash()) - receipts = append(receipts, core.GetBlockReceipts(pm.chaindb, block.Hash(), block.NumberU64())) + receipts = append(receipts, pm.blockchain.GetReceiptsByHash(block.Hash())) } // Send the hash request and verify the response p2p.Send(peer.app, 0x0f, hashes) @@ -472,7 +472,7 @@ func testDAOChallenge(t *testing.T, localForked, remoteForked bool, timeout bool config = ¶ms.ChainConfig{DAOForkBlock: big.NewInt(1), DAOForkSupport: localForked} gspec = &core.Genesis{Config: config} genesis = gspec.MustCommit(db) - blockchain, _ = core.NewBlockChain(db, config, pow, vm.Config{}) + blockchain, _ = core.NewBlockChain(db, nil, config, pow, vm.Config{}) ) pm, err := NewProtocolManager(config, downloader.FullSync, DefaultConfig.NetworkId, evmux, new(testTxPool), pow, blockchain, db) if err != nil { diff --git a/eth/helper_test.go b/eth/helper_test.go index 9a4dc9010..2b05cea80 100644 --- a/eth/helper_test.go +++ b/eth/helper_test.go @@ -49,7 +49,7 @@ var ( // newTestProtocolManager creates a new protocol manager for testing purposes, // with the given number of blocks already known, and potential notification // channels for different events. 
-func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, error) { +func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, *ethdb.MemDatabase, error) { var ( evmux = new(event.TypeMux) engine = ethash.NewFaker() @@ -59,7 +59,7 @@ func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func Alloc: core.GenesisAlloc{testBank: {Balance: big.NewInt(1000000)}}, } genesis = gspec.MustCommit(db) - blockchain, _ = core.NewBlockChain(db, gspec.Config, engine, vm.Config{}) + blockchain, _ = core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}) ) chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator) if _, err := blockchain.InsertChain(chain); err != nil { @@ -68,22 +68,22 @@ func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func pm, err := NewProtocolManager(gspec.Config, mode, DefaultConfig.NetworkId, evmux, &testTxPool{added: newtx}, engine, blockchain, db) if err != nil { - return nil, err + return nil, nil, err } pm.Start(1000) - return pm, nil + return pm, db, nil } // newTestProtocolManagerMust creates a new protocol manager for testing purposes, // with the given number of blocks already known, and potential notification // channels for different events. In case of an error, the constructor force- // fails the test. 
-func newTestProtocolManagerMust(t *testing.T, mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) *ProtocolManager { - pm, err := newTestProtocolManager(mode, blocks, generator, newtx) +func newTestProtocolManagerMust(t *testing.T, mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, *ethdb.MemDatabase) { + pm, db, err := newTestProtocolManager(mode, blocks, generator, newtx) if err != nil { t.Fatalf("Failed to create protocol manager: %v", err) } - return pm + return pm, db } // testTxPool is a fake, helper transaction pool for testing purposes diff --git a/eth/protocol_test.go b/eth/protocol_test.go index 7cbcba571..b2f93d8dd 100644 --- a/eth/protocol_test.go +++ b/eth/protocol_test.go @@ -41,7 +41,7 @@ func TestStatusMsgErrors62(t *testing.T) { testStatusMsgErrors(t, 62) } func TestStatusMsgErrors63(t *testing.T) { testStatusMsgErrors(t, 63) } func testStatusMsgErrors(t *testing.T, protocol int) { - pm := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) + pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) var ( genesis = pm.blockchain.Genesis() head = pm.blockchain.CurrentHeader() @@ -98,7 +98,7 @@ func TestRecvTransactions63(t *testing.T) { testRecvTransactions(t, 63) } func testRecvTransactions(t *testing.T, protocol int) { txAdded := make(chan []*types.Transaction) - pm := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded) + pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded) pm.acceptTxs = 1 // mark synced to accept transactions p, _ := newTestPeer("peer", protocol, pm, true) defer pm.Stop() @@ -125,7 +125,7 @@ func TestSendTransactions62(t *testing.T) { testSendTransactions(t, 62) } func TestSendTransactions63(t *testing.T) { testSendTransactions(t, 63) } func testSendTransactions(t *testing.T, protocol int) { - pm := newTestProtocolManagerMust(t, 
downloader.FullSync, 0, nil, nil) + pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) defer pm.Stop() // Fill the pool with big transactions. diff --git a/eth/sync_test.go b/eth/sync_test.go index 9eaa1156f..88c10c7f7 100644 --- a/eth/sync_test.go +++ b/eth/sync_test.go @@ -30,12 +30,12 @@ import ( // imported into the blockchain. func TestFastSyncDisabling(t *testing.T) { // Create a pristine protocol manager, check that fast sync is left enabled - pmEmpty := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil) + pmEmpty, _ := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil) if atomic.LoadUint32(&pmEmpty.fastSync) == 0 { t.Fatalf("fast sync disabled on pristine blockchain") } // Create a full protocol manager, check that fast sync gets disabled - pmFull := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil) + pmFull, _ := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil) if atomic.LoadUint32(&pmFull.fastSync) == 1 { t.Fatalf("fast sync not disabled on non-empty blockchain") } diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index a4cba7a4d..314086335 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -808,7 +808,7 @@ func (s *PublicBlockChainAPI) rpcOutputBlock(b *types.Block, inclTx bool, fullTx "difficulty": (*hexutil.Big)(head.Difficulty), "totalDifficulty": (*hexutil.Big)(s.b.GetTd(b.Hash())), "extraData": hexutil.Bytes(head.Extra), - "size": hexutil.Uint64(uint64(b.Size().Int64())), + "size": hexutil.Uint64(b.Size()), "gasLimit": hexutil.Uint64(head.GasLimit), "gasUsed": hexutil.Uint64(head.GasUsed), "timestamp": (*hexutil.Big)(head.Time), diff --git a/les/handler.go b/les/handler.go index 8cd37c7ab..5c93133fb 100644 --- a/les/handler.go +++ b/les/handler.go @@ -18,7 +18,6 @@ package les import ( - "bytes" "encoding/binary" "errors" "fmt" @@ -78,6 +77,7 @@ type BlockChain interface { GetHeaderByHash(hash common.Hash) *types.Header CurrentHeader() 
*types.Header GetTd(hash common.Hash, number uint64) *big.Int + State() (*state.StateDB, error) InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error) Rollback(chain []common.Hash) GetHeaderByNumber(number uint64) *types.Header @@ -579,17 +579,19 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { for _, req := range req.Reqs { // Retrieve the requested state entry, stopping if enough was found if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { - if trie, _ := trie.New(header.Root, pm.chainDb); trie != nil { - sdata := trie.Get(req.AccKey) - var acc state.Account - if err := rlp.DecodeBytes(sdata, &acc); err == nil { - entry, _ := pm.chainDb.Get(acc.CodeHash) - if bytes+len(entry) >= softResponseLimit { - break - } - data = append(data, entry) - bytes += len(entry) - } + statedb, err := pm.blockchain.State() + if err != nil { + continue + } + account, err := pm.getAccount(statedb, header.Root, common.BytesToHash(req.AccKey)) + if err != nil { + continue + } + code, _ := statedb.Database().TrieDB().Node(common.BytesToHash(account.CodeHash)) + + data = append(data, code) + if bytes += len(code); bytes >= softResponseLimit { + break } } } @@ -701,25 +703,29 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrRequestRejected, "") } for _, req := range req.Reqs { - if bytes >= softResponseLimit { - break - } // Retrieve the requested state entry, stopping if enough was found if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { - if tr, _ := trie.New(header.Root, pm.chainDb); tr != nil { - if len(req.AccKey) > 0 { - sdata := tr.Get(req.AccKey) - tr = nil - var acc state.Account - if err := rlp.DecodeBytes(sdata, &acc); err == nil { - tr, _ = trie.New(acc.Root, pm.chainDb) - } + statedb, err := pm.blockchain.State() + if err != nil { + continue + } + var trie state.Trie + if len(req.AccKey) > 0 { + account, err 
:= pm.getAccount(statedb, header.Root, common.BytesToHash(req.AccKey)) + if err != nil { + continue } - if tr != nil { - var proof light.NodeList - tr.Prove(req.Key, 0, &proof) - proofs = append(proofs, proof) - bytes += proof.DataSize() + trie, _ = statedb.Database().OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root) + } else { + trie, _ = statedb.Database().OpenTrie(header.Root) + } + if trie != nil { + var proof light.NodeList + trie.Prove(req.Key, 0, &proof) + + proofs = append(proofs, proof) + if bytes += proof.DataSize(); bytes >= softResponseLimit { + break } } } @@ -740,9 +746,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } // Gather state data until the fetch or network limits is reached var ( - lastBHash common.Hash - lastAccKey []byte - tr, str *trie.Trie + lastBHash common.Hash + statedb *state.StateDB + root common.Hash ) reqCnt := len(req.Reqs) if reject(uint64(reqCnt), MaxProofsFetch) { @@ -752,36 +758,37 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { nodes := light.NewNodeSet() for _, req := range req.Reqs { + // Look up the state belonging to the request + if statedb == nil || req.BHash != lastBHash { + statedb, root, lastBHash = nil, common.Hash{}, req.BHash + + if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { + statedb, _ = pm.blockchain.State() + root = header.Root + } + } + if statedb == nil { + continue + } + // Pull the account or storage trie of the request + var trie state.Trie + if len(req.AccKey) > 0 { + account, err := pm.getAccount(statedb, root, common.BytesToHash(req.AccKey)) + if err != nil { + continue + } + trie, _ = statedb.Database().OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root) + } else { + trie, _ = statedb.Database().OpenTrie(root) + } + if trie == nil { + continue + } + // Prove the user's request from the account or stroage trie + trie.Prove(req.Key, req.FromLevel, nodes) if nodes.DataSize() >= softResponseLimit { 
break } - if tr == nil || req.BHash != lastBHash { - if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil { - tr, _ = trie.New(header.Root, pm.chainDb) - } else { - tr = nil - } - lastBHash = req.BHash - str = nil - } - if tr != nil { - if len(req.AccKey) > 0 { - if str == nil || !bytes.Equal(req.AccKey, lastAccKey) { - sdata := tr.Get(req.AccKey) - str = nil - var acc state.Account - if err := rlp.DecodeBytes(sdata, &acc); err == nil { - str, _ = trie.New(acc.Root, pm.chainDb) - } - lastAccKey = common.CopyBytes(req.AccKey) - } - if str != nil { - str.Prove(req.Key, req.FromLevel, nodes) - } - } else { - tr.Prove(req.Key, req.FromLevel, nodes) - } - } } proofs := nodes.NodeList() bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) @@ -849,23 +856,29 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) { return errResp(ErrRequestRejected, "") } - trieDb := ethdb.NewTable(pm.chainDb, light.ChtTablePrefix) for _, req := range req.Reqs { - if bytes >= softResponseLimit { - break - } - if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil { sectionHead := core.GetCanonicalHash(pm.chainDb, req.ChtNum*light.ChtV1Frequency-1) if root := light.GetChtRoot(pm.chainDb, req.ChtNum-1, sectionHead); root != (common.Hash{}) { - if tr, _ := trie.New(root, trieDb); tr != nil { - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], req.BlockNum) - var proof light.NodeList - tr.Prove(encNumber[:], 0, &proof) - proofs = append(proofs, ChtResp{Header: header, Proof: proof}) - bytes += proof.DataSize() + estHeaderRlpSize + statedb, err := pm.blockchain.State() + if err != nil { + continue } + trie, err := statedb.Database().OpenTrie(root) + if err != nil { + continue + } + var encNumber [8]byte + binary.BigEndian.PutUint64(encNumber[:], req.BlockNum) + + var proof light.NodeList + trie.Prove(encNumber[:], 0, 
&proof) + + proofs = append(proofs, ChtResp{Header: header, Proof: proof}) + if bytes += proof.DataSize() + estHeaderRlpSize; bytes >= softResponseLimit { + break + } + } } } @@ -897,25 +910,21 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { lastIdx uint64 lastType uint root common.Hash - tr *trie.Trie + statedb *state.StateDB + trie state.Trie ) nodes := light.NewNodeSet() for _, req := range req.Reqs { - if nodes.DataSize()+auxBytes >= softResponseLimit { - break - } - if tr == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx { - var prefix string - root, prefix = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx) - if root != (common.Hash{}) { - if t, err := trie.New(root, ethdb.NewTable(pm.chainDb, prefix)); err == nil { - tr = t + if trie == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx { + statedb, trie, lastType, lastIdx = nil, nil, req.HelperTrieType, req.TrieIdx + + if root, _ = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx); root != (common.Hash{}) { + if statedb, _ = pm.blockchain.State(); statedb != nil { + trie, _ = statedb.Database().OpenTrie(root) } } - lastType = req.HelperTrieType - lastIdx = req.TrieIdx } if req.AuxReq == auxRoot { var data []byte @@ -925,8 +934,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { auxData = append(auxData, data) auxBytes += len(data) } else { - if tr != nil { - tr.Prove(req.Key, req.FromLevel, nodes) + if trie != nil { + trie.Prove(req.Key, req.FromLevel, nodes) } if req.AuxReq != 0 { data := pm.getHelperTrieAuxData(req) @@ -934,6 +943,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { auxBytes += len(data) } } + if nodes.DataSize()+auxBytes >= softResponseLimit { + break + } } proofs := nodes.NodeList() bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost) @@ -1090,6 +1102,23 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return nil } +// getAccount retrieves an account from the state based at root. 
+func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.Hash) (state.Account, error) { + trie, err := trie.New(root, statedb.Database().TrieDB()) + if err != nil { + return state.Account{}, err + } + blob, err := trie.TryGet(hash[:]) + if err != nil { + return state.Account{}, err + } + var account state.Account + if err = rlp.DecodeBytes(blob, &account); err != nil { + return state.Account{}, err + } + return account, nil +} + // getHelperTrie returns the post-processed trie root for the given trie ID and section index func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) { switch id { diff --git a/les/handler_test.go b/les/handler_test.go index 10e5499a3..e5446c031 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -359,7 +359,7 @@ func testGetProofs(t *testing.T, protocol int) { for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ { header := bc.GetHeaderByNumber(i) root := header.Root - trie, _ := trie.New(root, db) + trie, _ := trie.New(root, trie.NewDatabase(db)) for _, acc := range accounts { req := ProofReq{ diff --git a/les/helper_test.go b/les/helper_test.go index 1c1de64ad..bf08e1e2f 100644 --- a/les/helper_test.go +++ b/les/helper_test.go @@ -146,7 +146,7 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor if lightSync { chain, _ = light.NewLightChain(odr, gspec.Config, engine) } else { - blockchain, _ := core.NewBlockChain(db, gspec.Config, engine, vm.Config{}) + blockchain, _ := core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}) gchain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator) if _, err := blockchain.InsertChain(gchain); err != nil { panic(err) diff --git a/les/odr_test.go b/les/odr_test.go index cf609be88..88e121cda 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -101,7 +101,6 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon res = append(res, 
rlp...) } } - return res } diff --git a/light/lightchain.go b/light/lightchain.go index f47957512..24529ef82 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -18,6 +18,7 @@ package light import ( "context" + "errors" "math/big" "sync" "sync/atomic" @@ -26,6 +27,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" @@ -212,6 +214,11 @@ func (bc *LightChain) Genesis() *types.Block { return bc.genesisBlock } +// State returns a new mutable state based on the current HEAD block. +func (bc *LightChain) State() (*state.StateDB, error) { + return nil, errors.New("not implemented, needs client/server interface split") +} + // GetBody retrieves a block body (transactions and uncles) from the database // or ODR service by hash, caching it if found. func (self *LightChain) GetBody(ctx context.Context, hash common.Hash) (*types.Body, error) { diff --git a/light/nodeset.go b/light/nodeset.go index c530a4fbe..ffdb71bb7 100644 --- a/light/nodeset.go +++ b/light/nodeset.go @@ -22,8 +22,8 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" ) // NodeSet stores a set of trie nodes. It implements trie.Database and can also @@ -99,7 +99,7 @@ func (db *NodeSet) NodeList() NodeList { } // Store writes the contents of the set to the given database -func (db *NodeSet) Store(target trie.Database) { +func (db *NodeSet) Store(target ethdb.Putter) { db.lock.RLock() defer db.lock.RUnlock() @@ -108,11 +108,11 @@ func (db *NodeSet) Store(target trie.Database) { } } -// NodeList stores an ordered list of trie nodes. It implements trie.DatabaseWriter. 
+// NodeList stores an ordered list of trie nodes. It implements ethdb.Putter. type NodeList []rlp.RawValue // Store writes the contents of the list to the given database -func (n NodeList) Store(db trie.Database) { +func (n NodeList) Store(db ethdb.Putter) { for _, node := range n { db.Put(crypto.Keccak256(node), node) } diff --git a/light/odr_test.go b/light/odr_test.go index e3d07518a..d3f9374fd 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -74,7 +74,7 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error { case *ReceiptsRequest: req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash)) case *TrieRequest: - t, _ := trie.New(req.Id.Root, odr.sdb) + t, _ := trie.New(req.Id.Root, trie.NewDatabase(odr.sdb)) nodes := NewNodeSet() t.Prove(req.Key, 0, nodes) req.Proof = nodes @@ -239,7 +239,7 @@ func testChainOdr(t *testing.T, protocol int, fn odrTestFn) { ) gspec.MustCommit(ldb) // Assemble the test environment - blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{}) + blockchain, _ := core.NewBlockChain(sdb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{}) gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), sdb, 4, testChainGen) if _, err := blockchain.InsertChain(gchain); err != nil { t.Fatal(err) diff --git a/light/postprocess.go b/light/postprocess.go index 32dbc102b..bbac58d12 100644 --- a/light/postprocess.go +++ b/light/postprocess.go @@ -113,7 +113,8 @@ func StoreChtRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root common // ChtIndexerBackend implements core.ChainIndexerBackend type ChtIndexerBackend struct { - db, cdb ethdb.Database + diskdb ethdb.Database + triedb *trie.Database section, sectionSize uint64 lastHash common.Hash trie *trie.Trie @@ -121,8 +122,6 @@ type ChtIndexerBackend struct { // NewBloomTrieIndexer creates a BloomTrie chain indexer func NewChtIndexer(db 
ethdb.Database, clientMode bool) *core.ChainIndexer { - cdb := ethdb.NewTable(db, ChtTablePrefix) - idb := ethdb.NewTable(db, "chtIndex-") var sectionSize, confirmReq uint64 if clientMode { sectionSize = ChtFrequency @@ -131,17 +130,23 @@ func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer { sectionSize = ChtV1Frequency confirmReq = HelperTrieProcessConfirmations } - return core.NewChainIndexer(db, idb, &ChtIndexerBackend{db: db, cdb: cdb, sectionSize: sectionSize}, sectionSize, confirmReq, time.Millisecond*100, "cht") + idb := ethdb.NewTable(db, "chtIndex-") + backend := &ChtIndexerBackend{ + diskdb: db, + triedb: trie.NewDatabase(ethdb.NewTable(db, ChtTablePrefix)), + sectionSize: sectionSize, + } + return core.NewChainIndexer(db, idb, backend, sectionSize, confirmReq, time.Millisecond*100, "cht") } // Reset implements core.ChainIndexerBackend func (c *ChtIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error { var root common.Hash if section > 0 { - root = GetChtRoot(c.db, section-1, lastSectionHead) + root = GetChtRoot(c.diskdb, section-1, lastSectionHead) } var err error - c.trie, err = trie.New(root, c.cdb) + c.trie, err = trie.New(root, c.triedb) c.section = section return err } @@ -151,7 +156,7 @@ func (c *ChtIndexerBackend) Process(header *types.Header) { hash, num := header.Hash(), header.Number.Uint64() c.lastHash = hash - td := core.GetTd(c.db, hash, num) + td := core.GetTd(c.diskdb, hash, num) if td == nil { panic(nil) } @@ -163,17 +168,16 @@ func (c *ChtIndexerBackend) Process(header *types.Header) { // Commit implements core.ChainIndexerBackend func (c *ChtIndexerBackend) Commit() error { - batch := c.cdb.NewBatch() - root, err := c.trie.CommitTo(batch) + root, err := c.trie.Commit(nil) if err != nil { return err - } else { - batch.Write() - if ((c.section+1)*c.sectionSize)%ChtFrequency == 0 { - log.Info("Storing CHT", "idx", c.section*c.sectionSize/ChtFrequency, "sectionHead", fmt.Sprintf("%064x", c.lastHash), 
"root", fmt.Sprintf("%064x", root)) - } - StoreChtRoot(c.db, c.section, c.lastHash, root) } + c.triedb.Commit(root, false) + + if ((c.section+1)*c.sectionSize)%ChtFrequency == 0 { + log.Info("Storing CHT", "idx", c.section*c.sectionSize/ChtFrequency, "sectionHead", fmt.Sprintf("%064x", c.lastHash), "root", fmt.Sprintf("%064x", root)) + } + StoreChtRoot(c.diskdb, c.section, c.lastHash, root) return nil } @@ -205,7 +209,8 @@ func StoreBloomTrieRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root // BloomTrieIndexerBackend implements core.ChainIndexerBackend type BloomTrieIndexerBackend struct { - db, cdb ethdb.Database + diskdb ethdb.Database + triedb *trie.Database section, parentSectionSize, bloomTrieRatio uint64 trie *trie.Trie sectionHeads []common.Hash @@ -213,9 +218,12 @@ type BloomTrieIndexerBackend struct { // NewBloomTrieIndexer creates a BloomTrie chain indexer func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer { - cdb := ethdb.NewTable(db, BloomTrieTablePrefix) + backend := &BloomTrieIndexerBackend{ + diskdb: db, + triedb: trie.NewDatabase(ethdb.NewTable(db, BloomTrieTablePrefix)), + } idb := ethdb.NewTable(db, "bltIndex-") - backend := &BloomTrieIndexerBackend{db: db, cdb: cdb} + var confirmReq uint64 if clientMode { backend.parentSectionSize = BloomTrieFrequency @@ -233,10 +241,10 @@ func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer func (b *BloomTrieIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error { var root common.Hash if section > 0 { - root = GetBloomTrieRoot(b.db, section-1, lastSectionHead) + root = GetBloomTrieRoot(b.diskdb, section-1, lastSectionHead) } var err error - b.trie, err = trie.New(root, b.cdb) + b.trie, err = trie.New(root, b.triedb) b.section = section return err } @@ -259,7 +267,7 @@ func (b *BloomTrieIndexerBackend) Commit() error { binary.BigEndian.PutUint64(encKey[2:10], b.section) var decomp []byte for j := uint64(0); j < b.bloomTrieRatio; 
j++ { - data, err := core.GetBloomBits(b.db, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j]) + data, err := core.GetBloomBits(b.diskdb, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j]) if err != nil { return err } @@ -279,17 +287,15 @@ func (b *BloomTrieIndexerBackend) Commit() error { b.trie.Delete(encKey[:]) } } - - batch := b.cdb.NewBatch() - root, err := b.trie.CommitTo(batch) + root, err := b.trie.Commit(nil) if err != nil { return err - } else { - batch.Write() - sectionHead := b.sectionHeads[b.bloomTrieRatio-1] - log.Info("Storing BloomTrie", "section", b.section, "sectionHead", fmt.Sprintf("%064x", sectionHead), "root", fmt.Sprintf("%064x", root), "compression ratio", float64(compSize)/float64(decompSize)) - StoreBloomTrieRoot(b.db, b.section, sectionHead, root) } + b.triedb.Commit(root, false) + + sectionHead := b.sectionHeads[b.bloomTrieRatio-1] + log.Info("Storing BloomTrie", "section", b.section, "sectionHead", fmt.Sprintf("%064x", sectionHead), "root", fmt.Sprintf("%064x", root), "compression ratio", float64(compSize)/float64(decompSize)) + StoreBloomTrieRoot(b.diskdb, b.section, sectionHead, root) return nil } diff --git a/light/trie.go b/light/trie.go index 7a9c86b98..c07e99461 100644 --- a/light/trie.go +++ b/light/trie.go @@ -18,12 +18,14 @@ package light import ( "context" + "errors" "fmt" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/trie" ) @@ -83,6 +85,10 @@ func (db *odrDatabase) ContractCodeSize(addrHash, codeHash common.Hash) (int, er return len(code), err } +func (db *odrDatabase) TrieDB() *trie.Database { + return nil +} + type odrTrie struct { db *odrDatabase id *TrieID @@ -113,11 +119,11 @@ func (t *odrTrie) TryDelete(key []byte) error { }) } -func (t *odrTrie) CommitTo(db trie.DatabaseWriter) (common.Hash, error) { +func (t 
*odrTrie) Commit(onleaf trie.LeafCallback) (common.Hash, error) { if t.trie == nil { return t.id.Root, nil } - return t.trie.CommitTo(db) + return t.trie.Commit(onleaf) } func (t *odrTrie) Hash() common.Hash { @@ -135,13 +141,17 @@ func (t *odrTrie) GetKey(sha []byte) []byte { return nil } +func (t *odrTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error { + return errors.New("not implemented, needs client/server interface split") +} + // do tries and retries to execute a function until it returns with no error or // an error type other than MissingNodeError func (t *odrTrie) do(key []byte, fn func() error) error { for { var err error if t.trie == nil { - t.trie, err = trie.New(t.id.Root, t.db.backend.Database()) + t.trie, err = trie.New(t.id.Root, trie.NewDatabase(t.db.backend.Database())) } if err == nil { err = fn() @@ -167,7 +177,7 @@ func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator { // Open the actual non-ODR trie if that hasn't happened yet. if t.trie == nil { it.do(func() error { - t, err := trie.New(t.id.Root, t.db.backend.Database()) + t, err := trie.New(t.id.Root, trie.NewDatabase(t.db.backend.Database())) if err == nil { it.t.trie = t } diff --git a/light/trie_test.go b/light/trie_test.go index d99664718..0d6b2cc1d 100644 --- a/light/trie_test.go +++ b/light/trie_test.go @@ -40,7 +40,7 @@ func TestNodeIterator(t *testing.T) { genesis = gspec.MustCommit(fulldb) ) gspec.MustCommit(lightdb) - blockchain, _ := core.NewBlockChain(fulldb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{}) + blockchain, _ := core.NewBlockChain(fulldb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{}) gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), fulldb, 4, testChainGen) if _, err := blockchain.InsertChain(gchain); err != nil { panic(err) diff --git a/light/txpool_test.go b/light/txpool_test.go index b343f79b0..13d7d3ceb 100644 --- a/light/txpool_test.go +++ b/light/txpool_test.go @@ 
-88,7 +88,7 @@ func TestTxPool(t *testing.T) { ) gspec.MustCommit(ldb) // Assemble the test environment - blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{}) + blockchain, _ := core.NewBlockChain(sdb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{}) gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), sdb, poolTestBlocks, txPoolTestChainGen) if _, err := blockchain.InsertChain(gchain); err != nil { panic(err) diff --git a/miner/worker.go b/miner/worker.go index 1520277e1..15395ae0b 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -309,7 +309,7 @@ func (self *worker) wait() { for _, log := range work.state.Logs() { log.BlockHash = block.Hash() } - stat, err := self.chain.WriteBlockAndState(block, work.receipts, work.state) + stat, err := self.chain.WriteBlockWithState(block, work.receipts, work.state) if err != nil { log.Error("Failed writing block to chain", "err", err) continue diff --git a/tests/block_test_util.go b/tests/block_test_util.go index 4bfd6433f..beba48483 100644 --- a/tests/block_test_util.go +++ b/tests/block_test_util.go @@ -110,7 +110,7 @@ func (t *BlockTest) Run() error { return fmt.Errorf("genesis block state root does not match test: computed=%x, test=%x", gblock.Root().Bytes()[:6], t.json.Genesis.StateRoot[:6]) } - chain, err := core.NewBlockChain(db, config, ethash.NewShared(), vm.Config{}) + chain, err := core.NewBlockChain(db, nil, config, ethash.NewShared(), vm.Config{}) if err != nil { return err } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index 78c05b024..18280d2a4 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -125,7 +125,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD if !ok { return nil, UnsupportedForkError{subtest.Fork} } - block, _ := t.genesis(config).ToBlock() + block := t.genesis(config).ToBlock(nil) db, _ := ethdb.NewMemDatabase() statedb := 
MakePreState(db, t.json.Pre) @@ -147,7 +147,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) { return statedb, fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs) } - root, _ := statedb.CommitTo(db, config.IsEIP158(block.Number())) + root, _ := statedb.Commit(config.IsEIP158(block.Number())) if root != common.Hash(post.Root) { return statedb, fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root) } @@ -170,7 +170,7 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB } } // Commit and re-open to start with a clean state. - root, _ := statedb.CommitTo(db, false) + root, _ := statedb.Commit(false) statedb, _ = state.New(root, sdb) return statedb } diff --git a/trie/database.go b/trie/database.go new file mode 100644 index 000000000..d79120813 --- /dev/null +++ b/trie/database.go @@ -0,0 +1,355 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+
+package trie
+
+import (
+	"sync"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/log"
+)
+
+// secureKeyPrefix is the database key prefix used to store trie node preimages.
+var secureKeyPrefix = []byte("secure-key-")
+
+// secureKeyLength is the length of the above prefix + 32byte hash.
+const secureKeyLength = 11 + 32
+
+// DatabaseReader wraps the Get and Has method of a backing store for the trie.
+type DatabaseReader interface {
+	// Get retrieves the value associated with key from the database.
+	Get(key []byte) (value []byte, err error)
+
+	// Has retrieves whether a key is present in the database.
+	Has(key []byte) (bool, error)
+}
+
+// Database is an intermediate write layer between the trie data structures and
+// the disk database. The aim is to accumulate trie writes in-memory and only
+// periodically flush a couple tries to disk, garbage collecting the remainder.
+type Database struct {
+	diskdb ethdb.Database // Persistent storage for matured trie nodes
+
+	nodes     map[common.Hash]*cachedNode // Data and references relationships of a node
+	preimages map[common.Hash][]byte      // Preimages of nodes from the secure trie
+
+	// NOTE(review): seckeybuf is a single scratch buffer shared by every caller
+	// of secureKey; nothing in this type serialises access to it on its own.
+	seckeybuf [secureKeyLength]byte // Ephemeral buffer for calculating preimage keys
+
+	gctime  time.Duration      // Time spent on garbage collection since last commit
+	gcnodes uint64             // Nodes garbage collected since last commit
+	gcsize  common.StorageSize // Data storage garbage collected since last commit
+
+	nodesSize     common.StorageSize // Storage size of the nodes cache
+	preimagesSize common.StorageSize // Storage size of the preimages cache
+
+	// lock is taken by the exported methods around accesses to the node and
+	// preimage caches and their size / garbage collection counters.
+	lock sync.RWMutex
+}
+
+// cachedNode is all the information we know about a single cached node in the
+// memory database write layer.
+type cachedNode struct {
+	blob     []byte              // Cached data block of the trie node
+	parents  int                 // Number of live nodes referencing this one
+	children map[common.Hash]int // Children referenced by this node
+}
+
+// NewDatabase wraps diskdb in an in-memory write layer that holds trie nodes
+// and preimages until they are either persisted or garbage collected.
+func NewDatabase(diskdb ethdb.Database) *Database {
+	db := &Database{
+		diskdb:    diskdb,
+		nodes:     make(map[common.Hash]*cachedNode),
+		preimages: make(map[common.Hash][]byte),
+	}
+	// The zero hash entry is pre-created with no blob: it is the parent that
+	// references held on trie roots are attached to.
+	db.nodes[common.Hash{}] = &cachedNode{children: make(map[common.Hash]int)}
+
+	return db
+}
+
+// DiskDB exposes the persistent store behind the memory layer as a read-only
+// interface.
+func (db *Database) DiskDB() DatabaseReader {
+	return db.diskdb
+}
+
+// Insert caches a trie node under its hash unless an entry for that hash is
+// already present. The blob is copied, so the caller may reuse its slice.
+func (db *Database) Insert(hash common.Hash, blob []byte) {
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	db.insert(hash, blob)
+}
+
+// insert is Insert without the locking; the caller must hold the write lock.
+func (db *Database) insert(hash common.Hash, blob []byte) {
+	if _, known := db.nodes[hash]; !known {
+		entry := new(cachedNode)
+		entry.blob = common.CopyBytes(blob)
+		entry.children = make(map[common.Hash]int)
+
+		db.nodes[hash] = entry
+		db.nodesSize += common.StorageSize(common.HashLength + len(blob))
+	}
+}
+
+// insertPreimage caches the preimage of a secure trie key unless one is already
+// recorded for that hash. The preimage is copied before being stored.
+//
+// Note, this method assumes that the database's lock is held!
+func (db *Database) insertPreimage(hash common.Hash, preimage []byte) {
+	if _, known := db.preimages[hash]; !known {
+		db.preimages[hash] = common.CopyBytes(preimage)
+		db.preimagesSize += common.StorageSize(common.HashLength + len(preimage))
+	}
+}
+
+// Node retrieves a cached trie node from memory. If it cannot be found cached,
+// the method queries the persistent database for the content.
+func (db *Database) Node(hash common.Hash) ([]byte, error) {
+	// Retrieve the node from cache if available
+	db.lock.RLock()
+	node := db.nodes[hash]
+	db.lock.RUnlock()
+
+	if node != nil {
+		return node.blob, nil
+	}
+	// Content unavailable in memory, attempt to retrieve from disk
+	return db.diskdb.Get(hash[:])
+}
+
+// preimage retrieves a cached trie node pre-image from memory. If it cannot be
+// found cached, the method queries the persistent database for the content.
+func (db *Database) preimage(hash common.Hash) ([]byte, error) {
+	// Retrieve the node from cache if available
+	db.lock.RLock()
+	preimage := db.preimages[hash]
+	db.lock.RUnlock()
+
+	if preimage != nil {
+		return preimage, nil
+	}
+	// Content unavailable in memory, attempt to retrieve from disk. The lock is
+	// no longer held at this point, so the disk key is assembled in a buffer
+	// local to this call instead of the shared db.seckeybuf, which concurrent
+	// readers would otherwise overwrite underneath each other.
+	var buf [secureKeyLength]byte
+	key := append(append(buf[:0], secureKeyPrefix...), hash[:]...)
+
+	return db.diskdb.Get(key)
+}
+
+// secureKey returns the database key for the preimage of key, as an ephemeral
+// buffer. The caller must not hold onto the return value because it will become
+// invalid on the next call.
+//
+// Every call overwrites db.seckeybuf, so the method is not safe for concurrent
+// use: callers have to make sure no other goroutine is inside secureKey or is
+// still using a previously returned slice.
+func (db *Database) secureKey(key []byte) []byte {
+	buf := append(db.seckeybuf[:0], secureKeyPrefix...)
+	buf = append(buf, key...)
+	return buf
+}
+
+// Nodes retrieves the hashes of all the nodes cached within the memory database.
+// This method is extremely expensive and should only be used to validate internal
+// states in test code.
+func (db *Database) Nodes() []common.Hash {
+	db.lock.RLock()
+	defer db.lock.RUnlock()
+
+	var hashes = make([]common.Hash, 0, len(db.nodes))
+	for hash := range db.nodes {
+		if hash != (common.Hash{}) { // Special case for "root" references/nodes
+			hashes = append(hashes, hash)
+		}
+	}
+	return hashes
+}
+
+// Reference adds a new reference from a parent node to a child node.
+//
+// Adding a reference increments the child's parent counter and writes to the
+// parent's children map, so the write lock is required: with only a read lock
+// two concurrent callers could mutate the same map and counters at once.
+func (db *Database) Reference(child common.Hash, parent common.Hash) {
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	db.reference(child, parent)
+}
+
+// reference is the private locked version of Reference.
+func (db *Database) reference(child common.Hash, parent common.Hash) {
+	// Children that are not cached were loaded from disk; nothing to track
+	target, cached := db.nodes[child]
+	if !cached {
+		return
+	}
+	// An already recorded link is only counted again when the parent is the
+	// zero hash, i.e. when the reference is held on a root
+	links := db.nodes[parent].children
+	if _, linked := links[child]; linked && parent != (common.Hash{}) {
+		return
+	}
+	target.parents++
+	links[child]++
+}
+
+// Dereference drops one reference from parent to child and updates the garbage
+// collection counters with whatever the removal released from the cache.
+func (db *Database) Dereference(child common.Hash, parent common.Hash) {
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	var (
+		prevNodes = len(db.nodes)
+		prevSize  = db.nodesSize
+		begin     = time.Now()
+	)
+	db.dereference(child, parent)
+
+	freedNodes, freedSize := prevNodes-len(db.nodes), prevSize-db.nodesSize
+
+	db.gcnodes += uint64(freedNodes)
+	db.gcsize += freedSize
+	db.gctime += time.Since(begin)
+
+	log.Debug("Dereferenced trie from memory database", "nodes", freedNodes, "size", freedSize, "time", time.Since(begin),
+		"gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.nodes), "livesize", db.nodesSize)
+}
+
+// dereference is Dereference without the locking and statistics; the caller
+// must hold the write lock.
+func (db *Database) dereference(child common.Hash, parent common.Hash) {
+	// Drop one link from the parent, forgetting the child once none remain
+	links := db.nodes[parent].children
+
+	links[child]--
+	if links[child] == 0 {
+		delete(links, child)
+	}
+	// A child missing from the cache was committed earlier, nothing to release
+	target, cached := db.nodes[child]
+	if !cached {
+		return
+	}
+	// The child stays cached for as long as something else still refers to it
+	target.parents--
+	if target.parents != 0 {
+		return
+	}
+	// Last reference gone: release the child's own subtree, then the child
+	for grandchild := range target.children {
+		db.dereference(grandchild, child)
+	}
+	delete(db.nodes, child)
+	db.nodesSize -= common.StorageSize(common.HashLength + len(target.blob))
+}
+
+// Commit iterates over all the children of a particular node, writes them out
+// to disk, forcefully tearing down all references in both directions.
+//
+// As a side effect, all pre-images accumulated up to this point are also written.
+//
+// The commit runs in two phases. While holding the read lock, the preimages and
+// the trie rooted at node are written to disk in batches. Only after every write
+// has succeeded is the write lock taken to drop the persisted data from memory.
+// If report is true the summary is logged at Info level, otherwise at Debug.
+//
+// An error from any database write is returned and leaves the memory cache as
+// it was; every error path releases the read lock before returning.
+func (db *Database) Commit(node common.Hash, report bool) error {
+	// Create a database batch to flush persistent data out. It is important that
+	// outside code doesn't see an inconsistent state (referenced data removed from
+	// memory cache during commit but not yet in persistent storage). This is ensured
+	// by only uncaching existing data when the database write finalizes.
+	db.lock.RLock()
+
+	start := time.Now()
+	batch := db.diskdb.NewBatch()
+
+	// Only the read lock is held here, so the preimage keys are assembled in a
+	// buffer local to this call rather than in the shared db.seckeybuf, which
+	// another reader could be overwriting at the same time.
+	var keybuf [secureKeyLength]byte
+
+	// Move all of the accumulated preimages into a write batch
+	for hash, preimage := range db.preimages {
+		key := append(append(keybuf[:0], secureKeyPrefix...), hash[:]...)
+		if err := batch.Put(key, preimage); err != nil {
+			log.Error("Failed to commit preimage from trie database", "err", err)
+			db.lock.RUnlock()
+			return err
+		}
+		if batch.ValueSize() > ethdb.IdealBatchSize {
+			if err := batch.Write(); err != nil {
+				// The read lock must be released on this path too, otherwise
+				// the next writer blocks forever on db.lock.Lock()
+				log.Error("Failed to write preimages to disk", "err", err)
+				db.lock.RUnlock()
+				return err
+			}
+			batch.Reset()
+		}
+	}
+	// Move the trie itself into the batch, flushing if enough data is accumulated
+	nodes, storage := len(db.nodes), db.nodesSize+db.preimagesSize
+	if err := db.commit(node, batch); err != nil {
+		log.Error("Failed to commit trie from trie database", "err", err)
+		db.lock.RUnlock()
+		return err
+	}
+	// Write batch ready, flush the remainder and then unlock for the uncaching
+	if err := batch.Write(); err != nil {
+		log.Error("Failed to write trie to disk", "err", err)
+		db.lock.RUnlock()
+		return err
+	}
+	db.lock.RUnlock()
+
+	// Write successful, clear out the flushed data
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	db.preimages = make(map[common.Hash][]byte)
+	db.preimagesSize = 0
+
+	db.uncache(node)
+
+	logger := log.Info
+	if !report {
+		logger = log.Debug
+	}
+	logger("Persisted trie from memory database", "nodes", nodes-len(db.nodes), "size", storage-db.nodesSize, "time", time.Since(start),
+		"gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.nodes), "livesize", db.nodesSize)
+
+	// Reset the garbage collection statistics
+	db.gcnodes, db.gcsize, db.gctime = 0, 0, 0
+
+	return nil
+}
+
+// commit is the private locked version of Commit. It writes the children of a
+// node before the node itself and does not remove anything from the cache.
+func (db *Database) commit(hash common.Hash, batch ethdb.Batch) error {
+	// If the node does not exist, it's a previously committed node
+	node, ok := db.nodes[hash]
+	if !ok {
+		return nil
+	}
+	for child := range node.children {
+		if err := db.commit(child, batch); err != nil {
+			return err
+		}
+	}
+	if err := batch.Put(hash[:], node.blob); err != nil {
+		return err
+	}
+	// If we've reached an optimal batch size, commit and start over
+	if batch.ValueSize() >= ethdb.IdealBatchSize {
+		if err := batch.Write(); err != nil {
+			return err
+		}
+		batch.Reset()
+	}
+	return nil
+}
+
+// uncache is the post-processing step of a commit operation where the already
+// persisted trie is removed from the cache. The reason behind the two-phase
+// commit is to ensure consistent data availability while moving from memory
+// to disk.
+func (db *Database) uncache(hash common.Hash) {
+	// If the node does not exist, we're done on this path
+	node, ok := db.nodes[hash]
+	if !ok {
+		return
+	}
+	// Otherwise uncache the node's subtries and remove the node itself too
+	for child := range node.children {
+		db.uncache(child)
+	}
+	delete(db.nodes, hash)
+	db.nodesSize -= common.StorageSize(common.HashLength + len(node.blob))
+}
+
+// Size returns the current storage size of the memory cache in front of the
+// persistent database layer.
+func (db *Database) Size() common.StorageSize { + db.lock.RLock() + defer db.lock.RUnlock() + + return db.nodesSize + db.preimagesSize +} diff --git a/trie/hasher.go b/trie/hasher.go index 4719aabf6..2fc44787a 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -27,21 +27,23 @@ import ( ) type hasher struct { - tmp *bytes.Buffer - sha hash.Hash - cachegen, cachelimit uint16 + tmp *bytes.Buffer + sha hash.Hash + cachegen uint16 + cachelimit uint16 + onleaf LeafCallback } -// hashers live in a global pool. +// hashers live in a global db. var hasherPool = sync.Pool{ New: func() interface{} { return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()} }, } -func newHasher(cachegen, cachelimit uint16) *hasher { +func newHasher(cachegen, cachelimit uint16, onleaf LeafCallback) *hasher { h := hasherPool.Get().(*hasher) - h.cachegen, h.cachelimit = cachegen, cachelimit + h.cachegen, h.cachelimit, h.onleaf = cachegen, cachelimit, onleaf return h } @@ -51,7 +53,7 @@ func returnHasherToPool(h *hasher) { // hash collapses a node down into a hash node, also returning a copy of the // original node initialized with the computed hash to replace the original one. -func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) { +func (h *hasher) hash(n node, db *Database, force bool) (node, node, error) { // If we're not storing the node, just hashing, use available cached data if hash, dirty := n.cache(); hash != nil { if db == nil { @@ -98,7 +100,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) // hashChildren replaces the children of a node with their hashes if the encoded // size of the child is larger than a hash, returning the collapsed node as well // as a replacement for the original node with the child hashes cached in. 
-func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, error) { +func (h *hasher) hashChildren(original node, db *Database) (node, node, error) { var err error switch n := original.(type) { @@ -145,7 +147,10 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err } } -func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { +// store hashes the node n and if we have a storage layer specified, it writes +// the key/value pair to it and tracks any node->child references as well as any +// node->external trie references. +func (h *hasher) store(n node, db *Database, force bool) (node, error) { // Don't store hashes or empty nodes. if _, isHash := n.(hashNode); n == nil || isHash { return n, nil @@ -155,7 +160,6 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { if err := rlp.Encode(h.tmp, n); err != nil { panic("encode error: " + err.Error()) } - if h.tmp.Len() < 32 && !force { return n, nil // Nodes smaller than 32 bytes are stored inside their parent } @@ -167,7 +171,42 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { hash = hashNode(h.sha.Sum(nil)) } if db != nil { - return hash, db.Put(hash, h.tmp.Bytes()) + // We are pooling the trie nodes into an intermediate memory cache + db.lock.Lock() + + hash := common.BytesToHash(hash) + db.insert(hash, h.tmp.Bytes()) + + // Track all direct parent->child node references + switch n := n.(type) { + case *shortNode: + if child, ok := n.Val.(hashNode); ok { + db.reference(common.BytesToHash(child), hash) + } + case *fullNode: + for i := 0; i < 16; i++ { + if child, ok := n.Children[i].(hashNode); ok { + db.reference(common.BytesToHash(child), hash) + } + } + } + db.lock.Unlock() + + // Track external references from account->storage trie + if h.onleaf != nil { + switch n := n.(type) { + case *shortNode: + if child, ok := n.Val.(valueNode); ok { + h.onleaf(child, hash) + } + case *fullNode: + 
for i := 0; i < 16; i++ { + if child, ok := n.Children[i].(valueNode); ok { + h.onleaf(child, hash) + } + } + } + } } return hash, nil } diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 4808d8b0c..dce1c78b5 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -42,7 +42,7 @@ func TestIterator(t *testing.T) { all[val.k] = val.v trie.Update([]byte(val.k), []byte(val.v)) } - trie.Commit() + trie.Commit(nil) found := make(map[string]string) it := NewIterator(trie.NodeIterator(nil)) @@ -109,11 +109,18 @@ func TestNodeIteratorCoverage(t *testing.T) { } // Cross check the hashes and the database itself for hash := range hashes { - if _, err := db.Get(hash.Bytes()); err != nil { + if _, err := db.Node(hash); err != nil { t.Errorf("failed to retrieve reported node %x: %v", hash, err) } } - for _, key := range db.(*ethdb.MemDatabase).Keys() { + for hash, obj := range db.nodes { + if obj != nil && hash != (common.Hash{}) { + if _, ok := hashes[hash]; !ok { + t.Errorf("state entry not reported %x", hash) + } + } + } + for _, key := range db.diskdb.(*ethdb.MemDatabase).Keys() { if _, ok := hashes[common.BytesToHash(key)]; !ok { t.Errorf("state entry not reported %x", key) } @@ -191,13 +198,13 @@ func TestDifferenceIterator(t *testing.T) { for _, val := range testdata1 { triea.Update([]byte(val.k), []byte(val.v)) } - triea.Commit() + triea.Commit(nil) trieb := newEmpty() for _, val := range testdata2 { trieb.Update([]byte(val.k), []byte(val.v)) } - trieb.Commit() + trieb.Commit(nil) found := make(map[string]string) di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) @@ -227,13 +234,13 @@ func TestUnionIterator(t *testing.T) { for _, val := range testdata1 { triea.Update([]byte(val.k), []byte(val.v)) } - triea.Commit() + triea.Commit(nil) trieb := newEmpty() for _, val := range testdata2 { trieb.Update([]byte(val.k), []byte(val.v)) } - trieb.Commit() + trieb.Commit(nil) di, _ := 
NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)}) it := NewIterator(di) @@ -278,43 +285,75 @@ func TestIteratorNoDups(t *testing.T) { } // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes. -func TestIteratorContinueAfterError(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - tr, _ := New(common.Hash{}, db) +func TestIteratorContinueAfterErrorDisk(t *testing.T) { testIteratorContinueAfterError(t, false) } +func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) } + +func testIteratorContinueAfterError(t *testing.T, memonly bool) { + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + tr, _ := New(common.Hash{}, triedb) for _, val := range testdata1 { tr.Update([]byte(val.k), []byte(val.v)) } - tr.Commit() + tr.Commit(nil) + if !memonly { + triedb.Commit(tr.Hash(), true) + } wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil) - keys := db.Keys() - t.Log("node count", wantNodeCount) + var ( + diskKeys [][]byte + memKeys []common.Hash + ) + if memonly { + memKeys = triedb.Nodes() + } else { + diskKeys = diskdb.Keys() + } for i := 0; i < 20; i++ { // Create trie that will load all nodes from DB. - tr, _ := New(tr.Hash(), db) + tr, _ := New(tr.Hash(), triedb) // Remove a random node from the database. It can't be the root node // because that one is already loaded. - var rkey []byte + var ( + rkey common.Hash + rval []byte + robj *cachedNode + ) for { - if rkey = keys[rand.Intn(len(keys))]; !bytes.Equal(rkey, tr.Hash().Bytes()) { + if memonly { + rkey = memKeys[rand.Intn(len(memKeys))] + } else { + copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))]) + } + if rkey != tr.Hash() { break } } - rval, _ := db.Get(rkey) - db.Delete(rkey) - + if memonly { + robj = triedb.nodes[rkey] + delete(triedb.nodes, rkey) + } else { + rval, _ = diskdb.Get(rkey[:]) + diskdb.Delete(rkey[:]) + } // Iterate until the error is hit. 
seen := make(map[string]bool) it := tr.NodeIterator(nil) checkIteratorNoDups(t, it, seen) missing, ok := it.Error().(*MissingNodeError) - if !ok || !bytes.Equal(missing.NodeHash[:], rkey) { + if !ok || missing.NodeHash != rkey { t.Fatal("didn't hit missing node, got", it.Error()) } // Add the node back and continue iteration. - db.Put(rkey, rval) + if memonly { + triedb.nodes[rkey] = robj + } else { + diskdb.Put(rkey[:], rval) + } checkIteratorNoDups(t, it, seen) if it.Error() != nil { t.Fatal("unexpected error", it.Error()) @@ -328,21 +367,41 @@ func TestIteratorContinueAfterError(t *testing.T) { // Similar to the test above, this one checks that failure to create nodeIterator at a // certain key prefix behaves correctly when Next is called. The expectation is that Next // should retry seeking before returning true for the first time. -func TestIteratorContinueAfterSeekError(t *testing.T) { +func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) { + testIteratorContinueAfterSeekError(t, false) +} +func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) { + testIteratorContinueAfterSeekError(t, true) +} + +func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) { // Commit test trie to db, then remove the node containing "bars". 
- db, _ := ethdb.NewMemDatabase() - ctr, _ := New(common.Hash{}, db) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + ctr, _ := New(common.Hash{}, triedb) for _, val := range testdata1 { ctr.Update([]byte(val.k), []byte(val.v)) } - root, _ := ctr.Commit() + root, _ := ctr.Commit(nil) + if !memonly { + triedb.Commit(root, true) + } barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e") - barNode, _ := db.Get(barNodeHash[:]) - db.Delete(barNodeHash[:]) - + var ( + barNodeBlob []byte + barNodeObj *cachedNode + ) + if memonly { + barNodeObj = triedb.nodes[barNodeHash] + delete(triedb.nodes, barNodeHash) + } else { + barNodeBlob, _ = diskdb.Get(barNodeHash[:]) + diskdb.Delete(barNodeHash[:]) + } // Create a new iterator that seeks to "bars". Seeking can't proceed because // the node is missing. - tr, _ := New(root, db) + tr, _ := New(root, triedb) it := tr.NodeIterator([]byte("bars")) missing, ok := it.Error().(*MissingNodeError) if !ok { @@ -350,10 +409,12 @@ func TestIteratorContinueAfterSeekError(t *testing.T) { } else if missing.NodeHash != barNodeHash { t.Fatal("wrong node missing") } - // Reinsert the missing node. - db.Put(barNodeHash[:], barNode[:]) - + if memonly { + triedb.nodes[barNodeHash] = barNodeObj + } else { + diskdb.Put(barNodeHash[:], barNodeBlob) + } // Check that iteration produces the right set of values. if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil { t.Fatal(err) diff --git a/trie/proof.go b/trie/proof.go index 5e886a259..508e4a6cf 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -22,20 +22,19 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rlp" ) -// Prove constructs a merkle proof for key. The result contains all -// encoded nodes on the path to the value at key. 
The value itself is -// also included in the last node and can be retrieved by verifying -// the proof. +// Prove constructs a merkle proof for key. The result contains all encoded nodes +// on the path to the value at key. The value itself is also included in the last +// node and can be retrieved by verifying the proof. // -// If the trie does not contain a value for key, the returned proof -// contains all nodes of the longest existing prefix of the key -// (at least the root node), ending with the node that proves the -// absence of the key. -func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error { +// If the trie does not contain a value for key, the returned proof contains all +// nodes of the longest existing prefix of the key (at least the root node), ending +// with the node that proves the absence of the key. +func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error { // Collect all nodes on the path to key. key = keybytesToHex(key) nodes := []node{} @@ -66,7 +65,7 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error { panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) } } - hasher := newHasher(0, 0) + hasher := newHasher(0, 0, nil) for i, n := range nodes { // Don't bother checking for errors here since hasher panics // if encoding doesn't work and we're not writing to any database. @@ -89,19 +88,29 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error { return nil } -// VerifyProof checks merkle proofs. The given proof must contain the -// value for key in a trie with the given root hash. VerifyProof -// returns an error if the proof contains invalid trie nodes or the -// wrong value. +// Prove constructs a merkle proof for key. The result contains all encoded nodes +// on the path to the value at key. The value itself is also included in the last +// node and can be retrieved by verifying the proof. 
+// +// If the trie does not contain a value for key, the returned proof contains all +// nodes of the longest existing prefix of the key (at least the root node), ending +// with the node that proves the absence of the key. +func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error { + return t.trie.Prove(key, fromLevel, proofDb) +} + +// VerifyProof checks merkle proofs. The given proof must contain the value for +// key in a trie with the given root hash. VerifyProof returns an error if the +// proof contains invalid trie nodes or the wrong value. func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) { key = keybytesToHex(key) - wantHash := rootHash[:] + wantHash := rootHash for i := 0; ; i++ { - buf, _ := proofDb.Get(wantHash) + buf, _ := proofDb.Get(wantHash[:]) if buf == nil { - return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash[:]), i + return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash), i } - n, err := decodeNode(wantHash, buf, 0) + n, err := decodeNode(wantHash[:], buf, 0) if err != nil { return nil, fmt.Errorf("bad proof node %d: %v", i, err), i } @@ -112,7 +121,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (valu return nil, nil, i case hashNode: key = keyrest - wantHash = cld + copy(wantHash[:], cld) case valueNode: return cld, nil, i + 1 } diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 20c303f31..3881ee18a 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -23,10 +23,6 @@ import ( "github.com/ethereum/go-ethereum/log" ) -var secureKeyPrefix = []byte("secure-key-") - -const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash - // SecureTrie wraps a trie with key hashing. In a secure trie, all // access operations hash the key using keccak256. 
This prevents // calling code from creating long chains of nodes that @@ -39,25 +35,25 @@ const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash // SecureTrie is not safe for concurrent use. type SecureTrie struct { trie Trie - hashKeyBuf [secureKeyLength]byte - secKeyBuf [200]byte + hashKeyBuf [common.HashLength]byte secKeyCache map[string][]byte secKeyCacheOwner *SecureTrie // Pointer to self, replace the key cache on mismatch } -// NewSecure creates a trie with an existing root node from db. +// NewSecure creates a trie with an existing root node from a backing database +// and optional intermediate in-memory node pool. // // If root is the zero hash or the sha3 hash of an empty string, the // trie is initially empty. Otherwise, New will panic if db is nil // and returns MissingNodeError if the root node cannot be found. // -// Accessing the trie loads nodes from db on demand. +// Accessing the trie loads nodes from the database or node pool on demand. // Loaded nodes are kept around until their 'cache generation' expires. // A new cache generation is created by each call to Commit. // cachelimit sets the number of past cache generations to keep. -func NewSecure(root common.Hash, db Database, cachelimit uint16) (*SecureTrie, error) { +func NewSecure(root common.Hash, db *Database, cachelimit uint16) (*SecureTrie, error) { if db == nil { - panic("NewSecure called with nil database") + panic("trie.NewSecure called without a database") } trie, err := New(root, db) if err != nil { @@ -135,7 +131,7 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { return key } - key, _ := t.trie.db.Get(t.secKey(shaKey)) + key, _ := t.trie.db.preimage(common.BytesToHash(shaKey)) return key } @@ -144,8 +140,19 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { // // Committing flushes nodes from memory. Subsequent Get calls will load nodes // from the database. 
-func (t *SecureTrie) Commit() (root common.Hash, err error) { - return t.CommitTo(t.trie.db) +func (t *SecureTrie) Commit(onleaf LeafCallback) (root common.Hash, err error) { + // Write all the pre-images to the actual disk database + if len(t.getSecKeyCache()) > 0 { + t.trie.db.lock.Lock() + for hk, key := range t.secKeyCache { + t.trie.db.insertPreimage(common.BytesToHash([]byte(hk)), key) + } + t.trie.db.lock.Unlock() + + t.secKeyCache = make(map[string][]byte) + } + // Commit the trie to its intermediate node database + return t.trie.Commit(onleaf) } func (t *SecureTrie) Hash() common.Hash { @@ -167,38 +174,11 @@ func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { return t.trie.NodeIterator(start) } -// CommitTo writes all nodes and the secure hash pre-images to the given database. -// Nodes are stored with their sha3 hash as the key. -// -// Committing flushes nodes from memory. Subsequent Get calls will load nodes from -// the trie's database. Calling code must ensure that the changes made to db are -// written back to the trie's attached database before using the trie. -func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { - if len(t.getSecKeyCache()) > 0 { - for hk, key := range t.secKeyCache { - if err := db.Put(t.secKey([]byte(hk)), key); err != nil { - return common.Hash{}, err - } - } - t.secKeyCache = make(map[string][]byte) - } - return t.trie.CommitTo(db) -} - -// secKey returns the database key for the preimage of key, as an ephemeral buffer. -// The caller must not hold onto the return value because it will become -// invalid on the next call to hashKey or secKey. -func (t *SecureTrie) secKey(key []byte) []byte { - buf := append(t.secKeyBuf[:0], secureKeyPrefix...) - buf = append(buf, key...) - return buf -} - // hashKey returns the hash of key as an ephemeral buffer. // The caller must not hold onto the return value because it will become // invalid on the next call to hashKey or secKey. 
func (t *SecureTrie) hashKey(key []byte) []byte { - h := newHasher(0, 0) + h := newHasher(0, 0, nil) h.sha.Reset() h.sha.Write(key) buf := h.sha.Sum(t.hashKeyBuf[:0]) diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index d74102e2a..aedf5a1cd 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -28,16 +28,20 @@ import ( ) func newEmptySecure() *SecureTrie { - db, _ := ethdb.NewMemDatabase() - trie, _ := NewSecure(common.Hash{}, db, 0) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + trie, _ := NewSecure(common.Hash{}, triedb, 0) return trie } // makeTestSecureTrie creates a large enough secure trie for testing. -func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) { +func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) { // Create an empty trie - db, _ := ethdb.NewMemDatabase() - trie, _ := NewSecure(common.Hash{}, db, 0) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + trie, _ := NewSecure(common.Hash{}, triedb, 0) // Fill it with some arbitrary data content := make(map[string][]byte) @@ -58,10 +62,10 @@ func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) { trie.Update(key, val) } } - trie.Commit() + trie.Commit(nil) // Return the generated trie - return db, trie, content + return triedb, trie, content } func TestSecureDelete(t *testing.T) { @@ -137,7 +141,7 @@ func TestSecureTrieConcurrency(t *testing.T) { tries[index].Update(key, val) } } - tries[index].Commit() + tries[index].Commit(nil) }(i) } // Wait for all threads to finish diff --git a/trie/sync.go b/trie/sync.go index fea10051f..b573a9f73 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) @@ -42,7 +43,7 @@ type request struct { depth int // Depth level within the trie the node is located to prioritise 
DFS deps int // Number of dependencies before allowed to commit this node - callback TrieSyncLeafCallback // Callback to invoke if a leaf node it reached on this branch + callback LeafCallback // Callback to invoke if a leaf node it reached on this branch } // SyncResult is a simple list to return missing nodes along with their request @@ -67,11 +68,6 @@ func newSyncMemBatch() *syncMemBatch { } } -// TrieSyncLeafCallback is a callback type invoked when a trie sync reaches a -// leaf node. It's used by state syncing to check if the leaf node requires some -// further data syncing. -type TrieSyncLeafCallback func(leaf []byte, parent common.Hash) error - // TrieSync is the main state trie synchronisation scheduler, which provides yet // unknown trie hashes to retrieve, accepts node data associated with said hashes // and reconstructs the trie step by step until all is done. @@ -83,7 +79,7 @@ type TrieSync struct { } // NewTrieSync creates a new trie data download scheduler. -func NewTrieSync(root common.Hash, database DatabaseReader, callback TrieSyncLeafCallback) *TrieSync { +func NewTrieSync(root common.Hash, database DatabaseReader, callback LeafCallback) *TrieSync { ts := &TrieSync{ database: database, membatch: newSyncMemBatch(), @@ -95,7 +91,7 @@ func NewTrieSync(root common.Hash, database DatabaseReader, callback TrieSyncLea } // AddSubTrie registers a new trie to the sync code, rooted at the designated parent. -func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, callback TrieSyncLeafCallback) { +func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, callback LeafCallback) { // Short circuit if the trie is empty or already known if root == emptyRoot { return @@ -217,7 +213,7 @@ func (s *TrieSync) Process(results []SyncResult) (bool, int, error) { // Commit flushes the data stored in the internal membatch out to persistent // storage, returning th enumber of items written and any occurred error. 
-func (s *TrieSync) Commit(dbw DatabaseWriter) (int, error) { +func (s *TrieSync) Commit(dbw ethdb.Putter) (int, error) { // Dump the membatch into a database dbw for i, key := range s.membatch.order { if err := dbw.Put(key[:], s.membatch.batch[key]); err != nil { diff --git a/trie/sync_test.go b/trie/sync_test.go index ec16a25bd..4a720612b 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -25,10 +25,11 @@ import ( ) // makeTestTrie create a sample test trie to test node-wise reconstruction. -func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) { +func makeTestTrie() (*Database, *Trie, map[string][]byte) { // Create an empty trie - db, _ := ethdb.NewMemDatabase() - trie, _ := New(common.Hash{}, db) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + trie, _ := New(common.Hash{}, triedb) // Fill it with some arbitrary data content := make(map[string][]byte) @@ -49,15 +50,15 @@ func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) { trie.Update(key, val) } } - trie.Commit() + trie.Commit(nil) // Return the generated trie - return db, trie, content + return triedb, trie, content } // checkTrieContents cross references a reconstructed trie with an expected data // content map. -func checkTrieContents(t *testing.T, db Database, root []byte, content map[string][]byte) { +func checkTrieContents(t *testing.T, db *Database, root []byte, content map[string][]byte) { // Check root availability and trie contents trie, err := New(common.BytesToHash(root), db) if err != nil { @@ -74,7 +75,7 @@ func checkTrieContents(t *testing.T, db Database, root []byte, content map[strin } // checkTrieConsistency checks that all nodes in a trie are indeed present. 
-func checkTrieConsistency(db Database, root common.Hash) error { +func checkTrieConsistency(db *Database, root common.Hash) error { // Create and iterate a trie rooted in a subnode trie, err := New(root, db) if err != nil { @@ -88,12 +89,18 @@ func checkTrieConsistency(db Database, root common.Hash) error { // Tests that an empty trie is not scheduled for syncing. func TestEmptyTrieSync(t *testing.T) { - emptyA, _ := New(common.Hash{}, nil) - emptyB, _ := New(emptyRoot, nil) + diskdbA, _ := ethdb.NewMemDatabase() + triedbA := NewDatabase(diskdbA) + + diskdbB, _ := ethdb.NewMemDatabase() + triedbB := NewDatabase(diskdbB) + + emptyA, _ := New(common.Hash{}, triedbA) + emptyB, _ := New(emptyRoot, triedbB) for i, trie := range []*Trie{emptyA, emptyB} { - db, _ := ethdb.NewMemDatabase() - if req := NewTrieSync(common.BytesToHash(trie.Root()), db, nil).Missing(1); len(req) != 0 { + diskdb, _ := ethdb.NewMemDatabase() + if req := NewTrieSync(trie.Hash(), diskdb, nil).Missing(1); len(req) != 0 { t.Errorf("test %d: content requested for empty trie: %v", i, req) } } @@ -109,14 +116,15 @@ func testIterativeTrieSync(t *testing.T, batch int) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, _ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) queue := append([]common.Hash{}, sched.Missing(batch)...) 
for len(queue) > 0 { results := make([]SyncResult, len(queue)) for i, hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -125,13 +133,13 @@ func testIterativeTrieSync(t *testing.T, batch int) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } queue = append(queue[:0], sched.Missing(batch)...) } // Cross check that the two tries are in sync - checkTrieContents(t, dstDb, srcTrie.Root(), srcData) + checkTrieContents(t, triedb, srcTrie.Root(), srcData) } // Tests that the trie scheduler can correctly reconstruct the state even if only @@ -141,15 +149,16 @@ func TestIterativeDelayedTrieSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, _ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) queue := append([]common.Hash{}, sched.Missing(10000)...) 
for len(queue) > 0 { // Sync only half of the scheduled nodes results := make([]SyncResult, len(queue)/2+1) for i, hash := range queue[:len(results)] { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -158,13 +167,13 @@ func TestIterativeDelayedTrieSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } queue = append(queue[len(results):], sched.Missing(10000)...) } // Cross check that the two tries are in sync - checkTrieContents(t, dstDb, srcTrie.Root(), srcData) + checkTrieContents(t, triedb, srcTrie.Root(), srcData) } // Tests that given a root hash, a trie can sync iteratively on a single thread, @@ -178,8 +187,9 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, _ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) queue := make(map[common.Hash]struct{}) for _, hash := range sched.Missing(batch) { @@ -189,7 +199,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) { // Fetch all the queued nodes in a random order results := make([]SyncResult, 0, len(queue)) for hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -199,7 +209,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } 
- if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } queue = make(map[common.Hash]struct{}) @@ -208,7 +218,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) { } } // Cross check that the two tries are in sync - checkTrieContents(t, dstDb, srcTrie.Root(), srcData) + checkTrieContents(t, triedb, srcTrie.Root(), srcData) } // Tests that the trie scheduler can correctly reconstruct the state even if only @@ -218,8 +228,9 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, _ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) queue := make(map[common.Hash]struct{}) for _, hash := range sched.Missing(10000) { @@ -229,7 +240,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) { // Sync only half of the scheduled nodes, even those in random order results := make([]SyncResult, 0, len(queue)/2+1) for hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -243,7 +254,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } for _, result := range results { @@ -254,7 +265,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) { } } // Cross check that the two tries are in sync - checkTrieContents(t, dstDb, srcTrie.Root(), srcData) + checkTrieContents(t, triedb, 
srcTrie.Root(), srcData) } // Tests that a trie sync will not request nodes multiple times, even if they @@ -264,8 +275,9 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) { srcDb, srcTrie, srcData := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, _ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) queue := append([]common.Hash{}, sched.Missing(0)...) requested := make(map[common.Hash]struct{}) @@ -273,7 +285,7 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) { for len(queue) > 0 { results := make([]SyncResult, len(queue)) for i, hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -287,13 +299,13 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } queue = append(queue[:0], sched.Missing(0)...) 
} // Cross check that the two tries are in sync - checkTrieContents(t, dstDb, srcTrie.Root(), srcData) + checkTrieContents(t, triedb, srcTrie.Root(), srcData) } // Tests that at any point in time during a sync, only complete sub-tries are in @@ -303,8 +315,9 @@ func TestIncompleteTrieSync(t *testing.T) { srcDb, srcTrie, _ := makeTestTrie() // Create a destination trie and sync with the scheduler - dstDb, _ := ethdb.NewMemDatabase() - sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + sched := NewTrieSync(srcTrie.Hash(), diskdb, nil) added := []common.Hash{} queue := append([]common.Hash{}, sched.Missing(1)...) @@ -312,7 +325,7 @@ func TestIncompleteTrieSync(t *testing.T) { // Fetch a batch of trie nodes results := make([]SyncResult, len(queue)) for i, hash := range queue { - data, err := srcDb.Get(hash.Bytes()) + data, err := srcDb.Node(hash) if err != nil { t.Fatalf("failed to retrieve node data for %x: %v", hash, err) } @@ -322,7 +335,7 @@ func TestIncompleteTrieSync(t *testing.T) { if _, index, err := sched.Process(results); err != nil { t.Fatalf("failed to process result #%d: %v", index, err) } - if index, err := sched.Commit(dstDb); err != nil { + if index, err := sched.Commit(diskdb); err != nil { t.Fatalf("failed to commit data #%d: %v", index, err) } for _, result := range results { @@ -330,7 +343,7 @@ func TestIncompleteTrieSync(t *testing.T) { } // Check that all known sub-tries in the synced trie are complete for _, root := range added { - if err := checkTrieConsistency(dstDb, root); err != nil { + if err := checkTrieConsistency(triedb, root); err != nil { t.Fatalf("trie inconsistent: %v", err) } } @@ -340,12 +353,12 @@ func TestIncompleteTrieSync(t *testing.T) { // Sanity check that removing any node from the database is detected for _, node := range added[1:] { key := node.Bytes() - value, _ := dstDb.Get(key) + value, _ := diskdb.Get(key) - dstDb.Delete(key) - if 
err := checkTrieConsistency(dstDb, added[0]); err == nil { + diskdb.Delete(key) + if err := checkTrieConsistency(triedb, added[0]); err == nil { t.Fatalf("trie inconsistency not caught, missing: %x", key) } - dstDb.Put(key, value) + diskdb.Put(key, value) } } diff --git a/trie/trie.go b/trie/trie.go index 8fe98d835..e37a1ae10 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -22,16 +22,17 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto/sha3" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" "github.com/rcrowley/go-metrics" ) var ( - // This is the known root hash of an empty trie. + // emptyRoot is the known root hash of an empty trie. emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - // This is the known hash of an empty state trie entry. - emptyState common.Hash + + // emptyState is the known hash of an empty state trie entry. + emptyState = crypto.Keccak256Hash(nil) ) var ( @@ -53,29 +54,10 @@ func CacheUnloads() int64 { return cacheUnloadCounter.Count() } -func init() { - sha3.NewKeccak256().Sum(emptyState[:0]) -} - -// Database must be implemented by backing stores for the trie. -type Database interface { - DatabaseReader - DatabaseWriter -} - -// DatabaseReader wraps the Get method of a backing store for the trie. -type DatabaseReader interface { - Get(key []byte) (value []byte, err error) - Has(key []byte) (bool, error) -} - -// DatabaseWriter wraps the Put method of a backing store for the trie. -type DatabaseWriter interface { - // Put stores the mapping key->value in the database. - // Implementations must not hold onto the value bytes, the trie - // will reuse the slice across calls to Put. - Put(key, value []byte) error -} +// LeafCallback is a callback type invoked when a trie operation reaches a leaf +// node. It's used by state sync and commit to allow handling external references +// between account and storage tries. 
+type LeafCallback func(leaf []byte, parent common.Hash) error // Trie is a Merkle Patricia Trie. // The zero value is an empty trie with no database. @@ -83,8 +65,8 @@ type DatabaseWriter interface { // // Trie is not safe for concurrent use. type Trie struct { + db *Database root node - db Database originalRoot common.Hash // Cache generation values. @@ -111,12 +93,15 @@ func (t *Trie) newFlag() nodeFlag { // trie is initially empty and does not require a database. Otherwise, // New will panic if db is nil and returns a MissingNodeError if root does // not exist in the database. Accessing the trie loads nodes from db on demand. -func New(root common.Hash, db Database) (*Trie, error) { - trie := &Trie{db: db, originalRoot: root} +func New(root common.Hash, db *Database) (*Trie, error) { + if db == nil { + panic("trie.New called without a database") + } + trie := &Trie{ + db: db, + originalRoot: root, + } if (root != common.Hash{}) && root != emptyRoot { - if db == nil { - panic("trie.New: cannot use existing root without a database") - } rootnode, err := trie.resolveHash(root[:], nil) if err != nil { return nil, err @@ -447,12 +432,13 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) { func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { cacheMissCounter.Inc(1) - enc, err := t.db.Get(n) + hash := common.BytesToHash(n) + + enc, err := t.db.Node(hash) if err != nil || enc == nil { - return nil, &MissingNodeError{NodeHash: common.BytesToHash(n), Path: prefix} + return nil, &MissingNodeError{NodeHash: hash, Path: prefix} } - dec := mustDecodeNode(n, enc, t.cachegen) - return dec, nil + return mustDecodeNode(n, enc, t.cachegen), nil } // Root returns the root hash of the trie. @@ -462,32 +448,18 @@ func (t *Trie) Root() []byte { return t.Hash().Bytes() } // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. 
func (t *Trie) Hash() common.Hash { - hash, cached, _ := t.hashRoot(nil) + hash, cached, _ := t.hashRoot(nil, nil) t.root = cached return common.BytesToHash(hash.(hashNode)) } -// Commit writes all nodes to the trie's database. -// Nodes are stored with their sha3 hash as the key. -// -// Committing flushes nodes from memory. -// Subsequent Get calls will load nodes from the database. -func (t *Trie) Commit() (root common.Hash, err error) { +// Commit writes all nodes to the trie's memory database, tracking the internal +// and external (for account tries) references. +func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) { if t.db == nil { - panic("Commit called on trie with nil database") + panic("commit called on trie with nil database") } - return t.CommitTo(t.db) -} - -// CommitTo writes all nodes to the given database. -// Nodes are stored with their sha3 hash as the key. -// -// Committing flushes nodes from memory. Subsequent Get calls will -// load nodes from the trie's database. Calling code must ensure that -// the changes made to db are written back to the trie's attached -// database before using the trie. 
-func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { - hash, cached, err := t.hashRoot(db) + hash, cached, err := t.hashRoot(t.db, onleaf) if err != nil { return common.Hash{}, err } @@ -496,11 +468,11 @@ func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { return common.BytesToHash(hash.(hashNode)), nil } -func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) { +func (t *Trie) hashRoot(db *Database, onleaf LeafCallback) (node, node, error) { if t.root == nil { return hashNode(emptyRoot.Bytes()), nil, nil } - h := newHasher(t.cachegen, t.cachelimit) + h := newHasher(t.cachegen, t.cachelimit, onleaf) defer returnHasherToPool(h) return h.hash(t.root, db, true) } diff --git a/trie/trie_test.go b/trie/trie_test.go index 1e28c3bc4..997222628 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -43,8 +43,8 @@ func init() { // Used for testing func newEmpty() *Trie { - db, _ := ethdb.NewMemDatabase() - trie, _ := New(common.Hash{}, db) + diskdb, _ := ethdb.NewMemDatabase() + trie, _ := New(common.Hash{}, NewDatabase(diskdb)) return trie } @@ -68,8 +68,8 @@ func TestNull(t *testing.T) { } func TestMissingRoot(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), db) + diskdb, _ := ethdb.NewMemDatabase() + trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(diskdb)) if trie != nil { t.Error("New returned non-nil trie for invalid root") } @@ -78,70 +78,75 @@ func TestMissingRoot(t *testing.T) { } } -func TestMissingNode(t *testing.T) { - db, _ := ethdb.NewMemDatabase() - trie, _ := New(common.Hash{}, db) +func TestMissingNodeDisk(t *testing.T) { testMissingNode(t, false) } +func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) } + +func testMissingNode(t *testing.T, memonly bool) { + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + trie, _ := New(common.Hash{}, 
triedb) updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer") updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf") - root, _ := trie.Commit() + root, _ := trie.Commit(nil) + if !memonly { + triedb.Commit(root, true) + } - trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err := trie.TryGet([]byte("120000")) if err != nil { t.Errorf("Unexpected error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err = trie.TryGet([]byte("120099")) if err != nil { t.Errorf("Unexpected error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err = trie.TryGet([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) if err != nil { t.Errorf("Unexpected error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) err = trie.TryDelete([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } - db.Delete(common.FromHex("e1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")) + hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9") + if memonly { + delete(triedb.nodes, hash) + } else { + diskdb.Delete(hash[:]) + } - trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err = trie.TryGet([]byte("120000")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err = trie.TryGet([]byte("120099")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) _, err = trie.TryGet([]byte("123456")) if err != nil { t.Errorf("Unexpected error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) err = trie.TryUpdate([]byte("120099"), []byte("zxcv")) if _, ok := err.(*MissingNodeError); !ok { 
t.Errorf("Wrong error: %v", err) } - - trie, _ = New(root, db) + trie, _ = New(root, triedb) err = trie.TryDelete([]byte("123456")) if _, ok := err.(*MissingNodeError); !ok { t.Errorf("Wrong error: %v", err) @@ -165,7 +170,7 @@ func TestInsert(t *testing.T) { updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") - root, err := trie.Commit() + root, err := trie.Commit(nil) if err != nil { t.Fatalf("commit error: %v", err) } @@ -194,7 +199,7 @@ func TestGet(t *testing.T) { if i == 1 { return } - trie.Commit() + trie.Commit(nil) } } @@ -263,7 +268,7 @@ func TestReplication(t *testing.T) { for _, val := range vals { updateString(trie, val.k, val.v) } - exp, err := trie.Commit() + exp, err := trie.Commit(nil) if err != nil { t.Fatalf("commit error: %v", err) } @@ -278,7 +283,7 @@ func TestReplication(t *testing.T) { t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v) } } - hash, err := trie2.Commit() + hash, err := trie2.Commit(nil) if err != nil { t.Fatalf("commit error: %v", err) } @@ -314,7 +319,7 @@ func TestLargeValue(t *testing.T) { } type countingDB struct { - Database + ethdb.Database gets map[string]int } @@ -332,19 +337,20 @@ func TestCacheUnload(t *testing.T) { key2 := "---some other branch" updateString(trie, key1, "this is the branch of key1.") updateString(trie, key2, "this is the branch of key2.") - root, _ := trie.Commit() + + root, _ := trie.Commit(nil) + trie.db.Commit(root, true) // Commit the trie repeatedly and access key1. // The branch containing it is loaded from DB exactly two times: // in the 0th and 6th iteration. 
- db := &countingDB{Database: trie.db, gets: make(map[string]int)} - trie, _ = New(root, db) + db := &countingDB{Database: trie.db.diskdb, gets: make(map[string]int)} + trie, _ = New(root, NewDatabase(db)) trie.SetCacheLimit(5) for i := 0; i < 12; i++ { getString(trie, key1) - trie.Commit() + trie.Commit(nil) } - // Check that it got loaded two times. for dbkey, count := range db.gets { if count != 2 { @@ -407,8 +413,10 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value { } func runRandTest(rt randTest) bool { - db, _ := ethdb.NewMemDatabase() - tr, _ := New(common.Hash{}, db) + diskdb, _ := ethdb.NewMemDatabase() + triedb := NewDatabase(diskdb) + + tr, _ := New(common.Hash{}, triedb) values := make(map[string]string) // tracks content of the trie for i, step := range rt { @@ -426,23 +434,23 @@ func runRandTest(rt randTest) bool { rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want) } case opCommit: - _, rt[i].err = tr.Commit() + _, rt[i].err = tr.Commit(nil) case opHash: tr.Hash() case opReset: - hash, err := tr.Commit() + hash, err := tr.Commit(nil) if err != nil { rt[i].err = err return false } - newtr, err := New(hash, db) + newtr, err := New(hash, triedb) if err != nil { rt[i].err = err return false } tr = newtr case opItercheckhash: - checktr, _ := New(common.Hash{}, nil) + checktr, _ := New(common.Hash{}, triedb) it := NewIterator(tr.NodeIterator(nil)) for it.Next() { checktr.Update(it.Key, it.Value) @@ -524,7 +532,7 @@ func benchGet(b *testing.B, commit bool) { } binary.LittleEndian.PutUint64(k, benchElemCount/2) if commit { - trie.Commit() + trie.Commit(nil) } b.ResetTimer() @@ -534,7 +542,7 @@ func benchGet(b *testing.B, commit bool) { b.StopTimer() if commit { - ldb := trie.db.(*ethdb.LDBDatabase) + ldb := trie.db.diskdb.(*ethdb.LDBDatabase) ldb.Close() os.RemoveAll(ldb.Path()) } @@ -585,16 +593,16 @@ func BenchmarkHash(b *testing.B) { trie.Hash() } -func tempDB() (string, Database) { +func tempDB() 
(string, *Database) { dir, err := ioutil.TempDir("", "trie-bench") if err != nil { panic(fmt.Sprintf("can't create temporary directory: %v", err)) } - db, err := ethdb.NewLDBDatabase(dir, 256, 0) + diskdb, err := ethdb.NewLDBDatabase(dir, 256, 0) if err != nil { panic(fmt.Sprintf("can't create temporary database: %v", err)) } - return dir, db + return dir, NewDatabase(diskdb) } func getString(trie *Trie, k string) []byte {