From 5ddedd2f83729cc91b6e52858f9b3ff58888836d Mon Sep 17 00:00:00 2001
From: rjl493456442 <garyrong0905@gmail.com>
Date: Wed, 7 Sep 2022 15:08:56 +0800
Subject: [PATCH] core, light, trie: remove DiskDB function from trie database
 (#25690)

---
 core/state/database.go               | 14 ++++++++++++--
 core/state/iterator_test.go          |  3 +--
 core/state/snapshot/generate_test.go | 24 ++++++++++++------------
 core/state/statedb.go                |  2 +-
 core/state/sync_test.go              |  4 ++--
 light/trie.go                        |  4 ++++
 trie/database.go                     |  5 -----
 7 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/core/state/database.go b/core/state/database.go
index 96b6bcfe6..9b4fd8946 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -54,6 +54,9 @@ type Database interface {
 	// ContractCodeSize retrieves a particular contracts code's size.
 	ContractCodeSize(addrHash, codeHash common.Hash) (int, error)
 
+	// DiskDB returns the underlying key-value disk database.
+	DiskDB() ethdb.KeyValueStore
+
 	// TrieDB retrieves the low level trie database used for data storage.
 	TrieDB() *trie.Database
 }
@@ -130,6 +133,7 @@ func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database {
 	csc, _ := lru.New(codeSizeCacheSize)
 	return &cachingDB{
 		db:            trie.NewDatabaseWithConfig(db, config),
+		disk:          db,
 		codeSizeCache: csc,
 		codeCache:     fastcache.New(codeCacheSize),
 	}
@@ -137,6 +141,7 @@ func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database {
 
 type cachingDB struct {
 	db            *trie.Database
+	disk          ethdb.KeyValueStore
 	codeSizeCache *lru.Cache
 	codeCache     *fastcache.Cache
 }
@@ -174,7 +179,7 @@ func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error
 	if code := db.codeCache.Get(nil, codeHash.Bytes()); len(code) > 0 {
 		return code, nil
 	}
-	code := rawdb.ReadCode(db.db.DiskDB(), codeHash)
+	code := rawdb.ReadCode(db.disk, codeHash)
 	if len(code) > 0 {
 		db.codeCache.Set(codeHash.Bytes(), code)
 		db.codeSizeCache.Add(codeHash, len(code))
@@ -190,7 +195,7 @@ func (db *cachingDB) ContractCodeWithPrefix(addrHash, codeHash common.Hash) ([]b
 	if code := db.codeCache.Get(nil, codeHash.Bytes()); len(code) > 0 {
 		return code, nil
 	}
-	code := rawdb.ReadCodeWithPrefix(db.db.DiskDB(), codeHash)
+	code := rawdb.ReadCodeWithPrefix(db.disk, codeHash)
 	if len(code) > 0 {
 		db.codeCache.Set(codeHash.Bytes(), code)
 		db.codeSizeCache.Add(codeHash, len(code))
@@ -208,6 +213,11 @@ func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, erro
 	return len(code), err
 }
 
+// DiskDB returns the underlying key-value disk database.
+func (db *cachingDB) DiskDB() ethdb.KeyValueStore {
+	return db.disk
+}
+
 // TrieDB retrieves any intermediate trie-node caching layer.
 func (db *cachingDB) TrieDB() *trie.Database {
 	return db.db
diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go
index d1afe9ca3..f93375126 100644
--- a/core/state/iterator_test.go
+++ b/core/state/iterator_test.go
@@ -21,7 +21,6 @@ import (
 	"testing"
 
 	"github.com/ethereum/go-ethereum/common"
-	"github.com/ethereum/go-ethereum/ethdb"
 )
 
 // Tests that the node iterator indeed walks over the entire database contents.
@@ -55,7 +54,7 @@ func TestNodeIteratorCoverage(t *testing.T) {
 			t.Errorf("state entry not reported %x", hash)
 		}
 	}
-	it := db.TrieDB().DiskDB().(ethdb.Database).NewIterator(nil, nil)
+	it := db.DiskDB().NewIterator(nil, nil)
 	for it.Next() {
 		key := it.Key()
 		if bytes.HasPrefix(key, []byte("secure-key-")) {
diff --git a/core/state/snapshot/generate_test.go b/core/state/snapshot/generate_test.go
index 5e5ded61e..447ca80ca 100644
--- a/core/state/snapshot/generate_test.go
+++ b/core/state/snapshot/generate_test.go
@@ -491,12 +491,12 @@ func TestGenerateWithExtraAccounts(t *testing.T) {
 
 		// Identical in the snap
 		key := hashData([]byte("acc-1"))
-		rawdb.WriteAccountSnapshot(helper.triedb.DiskDB(), key, val)
-		rawdb.WriteStorageSnapshot(helper.triedb.DiskDB(), key, hashData([]byte("key-1")), []byte("val-1"))
-		rawdb.WriteStorageSnapshot(helper.triedb.DiskDB(), key, hashData([]byte("key-2")), []byte("val-2"))
-		rawdb.WriteStorageSnapshot(helper.triedb.DiskDB(), key, hashData([]byte("key-3")), []byte("val-3"))
-		rawdb.WriteStorageSnapshot(helper.triedb.DiskDB(), key, hashData([]byte("key-4")), []byte("val-4"))
-		rawdb.WriteStorageSnapshot(helper.triedb.DiskDB(), key, hashData([]byte("key-5")), []byte("val-5"))
+		rawdb.WriteAccountSnapshot(helper.diskdb, key, val)
+		rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-1")), []byte("val-1"))
+		rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-2")), []byte("val-2"))
+		rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-3")), []byte("val-3"))
+		rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-4")), []byte("val-4"))
+		rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-5")), []byte("val-5"))
 	}
 	{
 		// Account two exists only in the snapshot
@@ -508,15 +508,15 @@ func TestGenerateWithExtraAccounts(t *testing.T) {
 		acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: emptyCode.Bytes()}
 		val, _ := rlp.EncodeToBytes(acc)
 		key := hashData([]byte("acc-2"))
-		rawdb.WriteAccountSnapshot(helper.triedb.DiskDB(), key, val)
-		rawdb.WriteStorageSnapshot(helper.triedb.DiskDB(), key, hashData([]byte("b-key-1")), []byte("b-val-1"))
-		rawdb.WriteStorageSnapshot(helper.triedb.DiskDB(), key, hashData([]byte("b-key-2")), []byte("b-val-2"))
-		rawdb.WriteStorageSnapshot(helper.triedb.DiskDB(), key, hashData([]byte("b-key-3")), []byte("b-val-3"))
+		rawdb.WriteAccountSnapshot(helper.diskdb, key, val)
+		rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-1")), []byte("b-val-1"))
+		rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-2")), []byte("b-val-2"))
+		rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-3")), []byte("b-val-3"))
 	}
 	root := helper.Commit()
 
 	// To verify the test: If we now inspect the snap db, there should exist extraneous storage items
-	if data := rawdb.ReadStorageSnapshot(helper.triedb.DiskDB(), hashData([]byte("acc-2")), hashData([]byte("b-key-1"))); data == nil {
+	if data := rawdb.ReadStorageSnapshot(helper.diskdb, hashData([]byte("acc-2")), hashData([]byte("b-key-1"))); data == nil {
 		t.Fatalf("expected snap storage to exist")
 	}
 	snap := generateSnapshot(helper.diskdb, helper.triedb, 16, root)
@@ -534,7 +534,7 @@ func TestGenerateWithExtraAccounts(t *testing.T) {
 	snap.genAbort <- stop
 	<-stop
 	// If we now inspect the snap db, there should exist no extraneous storage items
-	if data := rawdb.ReadStorageSnapshot(helper.triedb.DiskDB(), hashData([]byte("acc-2")), hashData([]byte("b-key-1"))); data != nil {
+	if data := rawdb.ReadStorageSnapshot(helper.diskdb, hashData([]byte("acc-2")), hashData([]byte("b-key-1"))); data != nil {
 		t.Fatalf("expected slot to be removed, got %v", string(data))
 	}
 }
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 50eee8183..a649e0bd1 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -908,7 +908,7 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) {
 		storageTrieNodes int
 		nodes            = trie.NewMergedNodeSet()
 	)
-	codeWriter := s.db.TrieDB().DiskDB().NewBatch()
+	codeWriter := s.db.DiskDB().NewBatch()
 	for addr := range s.stateObjectsDirty {
 		if obj := s.stateObjects[addr]; !obj.deleted {
 			// Write any contract code associated with the state object
diff --git a/core/state/sync_test.go b/core/state/sync_test.go
index 3d9fe556d..d16c7ce73 100644
--- a/core/state/sync_test.go
+++ b/core/state/sync_test.go
@@ -100,7 +100,7 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou
 }
 
 // checkTrieConsistency checks that all nodes in a (sub-)trie are indeed present.
-func checkTrieConsistency(db ethdb.Database, root common.Hash) error {
+func checkTrieConsistency(db ethdb.KeyValueStore, root common.Hash) error {
 	if v, _ := db.Get(root[:]); v == nil {
 		return nil // Consider a non existent state consistent.
 	}
@@ -553,7 +553,7 @@ func TestIncompleteStateSync(t *testing.T) {
 		}
 	}
 	isCode[common.BytesToHash(emptyCodeHash)] = struct{}{}
-	checkTrieConsistency(srcDb.TrieDB().DiskDB().(ethdb.Database), srcRoot)
+	checkTrieConsistency(srcDb.DiskDB(), srcRoot)
 
 	// Create a destination state and sync with the scheduler
 	dstDb := rawdb.NewMemoryDatabase()
diff --git a/light/trie.go b/light/trie.go
index b88265e87..0f2e38625 100644
--- a/light/trie.go
+++ b/light/trie.go
@@ -96,6 +96,10 @@ func (db *odrDatabase) TrieDB() *trie.Database {
 	return nil
 }
 
+func (db *odrDatabase) DiskDB() ethdb.KeyValueStore {
+	panic("not implemented")
+}
+
 type odrTrie struct {
 	db   *odrDatabase
 	id   *TrieID
diff --git a/trie/database.go b/trie/database.go
index 8d426c252..cca2bb085 100644
--- a/trie/database.go
+++ b/trie/database.go
@@ -304,11 +304,6 @@ func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database
 	return db
 }
 
-// DiskDB retrieves the persistent storage backing the trie database.
-func (db *Database) DiskDB() ethdb.KeyValueStore {
-	return db.diskdb
-}
-
 // insert inserts a simplified trie node into the memory database.
 // All nodes inserted by this function will be reference tracked
 // and in theory should only used for **trie nodes** insertion.