From 49177ee0ab50f6207519be22d5e0bf90048cd47f Mon Sep 17 00:00:00 2001 From: Roy Crihfield Date: Tue, 11 Apr 2023 17:42:11 +0800 Subject: [PATCH] Refactor redundant DB code --- config.go | 90 ----------------- pgx.go | 13 +-- sqlx.go | 20 +--- state_database.go | 20 +--- statedb_test.go | 248 +++++----------------------------------------- 5 files changed, 32 insertions(+), 359 deletions(-) delete mode 100644 config.go diff --git a/config.go b/config.go deleted file mode 100644 index aba21df..0000000 --- a/config.go +++ /dev/null @@ -1,90 +0,0 @@ -package ipld_eth_statedb - -import ( - "context" - "fmt" - "time" - - "github.com/jackc/pgx/v4/pgxpool" - "github.com/jmoiron/sqlx" - - _ "github.com/lib/pq" -) - -type Config struct { - Hostname string - Port int - DatabaseName string - Username string - Password string - - ConnTimeout time.Duration - MaxConns int - MinConns int - MaxConnLifetime time.Duration - MaxConnIdleTime time.Duration - MaxIdle int -} - -// DbConnectionString constructs and returns the connection string from the config (for sqlx driver) -func (c Config) DbConnectionString() string { - if len(c.Username) > 0 && len(c.Password) > 0 { - return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable", - c.Username, c.Password, c.Hostname, c.Port, c.DatabaseName) - } - if len(c.Username) > 0 && len(c.Password) == 0 { - return fmt.Sprintf("postgresql://%s@%s:%d/%s?sslmode=disable", - c.Username, c.Hostname, c.Port, c.DatabaseName) - } - return fmt.Sprintf("postgresql://%s:%d/%s?sslmode=disable", c.Hostname, c.Port, c.DatabaseName) -} - -// NewPGXPool returns a new pgx conn pool -func NewPGXPool(ctx context.Context, config Config) (*pgxpool.Pool, error) { - pgConf, err := makePGXConfig(config) - if err != nil { - return nil, err - } - return pgxpool.ConnectConfig(ctx, pgConf) -} - -// NewSQLXPool returns a new sqlx conn pool -func NewSQLXPool(ctx context.Context, config Config) (*sqlx.DB, error) { - db, err := sqlx.ConnectContext(ctx, "postgres", config.DbConnectionString()) - if err != nil { - return nil, err - } - return db, nil -} - -// makePGXConfig creates a pgxpool.Config from the provided Config -func makePGXConfig(config Config) (*pgxpool.Config, error) { - conf, err := pgxpool.ParseConfig("") - if err != nil { - return nil, err - } - - conf.ConnConfig.Config.Host = config.Hostname - conf.ConnConfig.Config.Port = uint16(config.Port) - conf.ConnConfig.Config.Database = config.DatabaseName - conf.ConnConfig.Config.User = config.Username - conf.ConnConfig.Config.Password = config.Password - - if config.ConnTimeout != 0 { - conf.ConnConfig.Config.ConnectTimeout = config.ConnTimeout - } - if config.MaxConns != 0 { - conf.MaxConns = int32(config.MaxConns) - } - if config.MinConns != 0 { - conf.MinConns = int32(config.MinConns) - } - if config.MaxConnLifetime != 0 { - conf.MaxConnLifetime = config.MaxConnLifetime - } - if config.MaxConnIdleTime != 0 { - conf.MaxConnIdleTime = config.MaxConnIdleTime - } - - return conf, nil -} diff --git a/pgx.go b/pgx.go index c780d20..e60a6d5 100644 --- a/pgx.go +++ b/pgx.go @@ -15,18 +15,9 @@ type PGXDriver struct { db *pgxpool.Pool } -// NewPGXDriver returns a new pgx driver for Postgres -func NewPGXDriver(ctx context.Context, config Config) (*PGXDriver, error) { - db, err := NewPGXPool(ctx, config) - if err != nil { - return nil, err - } - return &PGXDriver{ctx: ctx, db: db}, nil -} - // NewPGXDriverFromPool returns a new pgx driver for Postgres -func NewPGXDriverFromPool(ctx context.Context, db *pgxpool.Pool) (*PGXDriver, error) { - return &PGXDriver{ctx: ctx, db: db}, nil +func NewPGXDriverFromPool(ctx context.Context, db *pgxpool.Pool) *PGXDriver { + return &PGXDriver{ctx: ctx, db: db} } // QueryRow satisfies sql.Database diff --git a/sqlx.go b/sqlx.go index 92f4c03..2eed275 100644 --- a/sqlx.go +++ b/sqlx.go @@ -14,25 +14,9 @@ type SQLXDriver struct { db *sqlx.DB } -// NewSQLXDriver returns a new sqlx driver for Postgres -func NewSQLXDriver(ctx context.Context, config Config) (*SQLXDriver, error) { - db, err := NewSQLXPool(ctx, config) - if err != nil { - return nil, err - } - if config.MaxConns > 0 { - db.SetMaxOpenConns(config.MaxConns) - } - if config.MaxConnLifetime > 0 { - db.SetConnMaxLifetime(config.MaxConnLifetime) - } - db.SetMaxIdleConns(config.MaxIdle) - return &SQLXDriver{ctx: ctx, db: db}, nil -} - // NewSQLXDriverFromPool returns a new sqlx driver for Postgres -func NewSQLXDriverFromPool(ctx context.Context, db *sqlx.DB) (*SQLXDriver, error) { - return &SQLXDriver{ctx: ctx, db: db}, nil +func NewSQLXDriverFromPool(ctx context.Context, db *sqlx.DB) *SQLXDriver { + return &SQLXDriver{ctx: ctx, db: db} } // QueryRow satisfies sql.Database diff --git a/state_database.go b/state_database.go index ee115e2..cb2c7b4 100644 --- a/state_database.go +++ b/state_database.go @@ -8,8 +8,6 @@ import ( "github.com/VictoriaMetrics/fastcache" lru "github.com/hashicorp/golang-lru" - "github.com/jackc/pgx/v4/pgxpool" - "github.com/jmoiron/sqlx" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -48,24 +46,14 @@ type stateDatabase struct { codeCache *fastcache.Cache } -// NewStateDatabaseWithPgxPool returns a new Database implementation using the provided postgres connection pool -func NewStateDatabaseWithPgxPool(pgDb *pgxpool.Pool) (*stateDatabase, error) { +// NewStateDatabase returns a new Database implementation using the passed parameters +func NewStateDatabase(db Database) *stateDatabase { csc, _ := lru.New(codeSizeCacheSize) return &stateDatabase{ - db: NewPostgresDB(&PGXDriver{db: pgDb}), + db: db, codeSizeCache: csc, codeCache: fastcache.New(codeCacheSize), - }, nil -} - -// NewStateDatabaseWithSqlxPool returns a new Database implementation using the passed parameters -func NewStateDatabaseWithSqlxPool(db *sqlx.DB) (*stateDatabase, error) { - csc, _ := lru.New(codeSizeCacheSize) - return &stateDatabase{ - db: NewPostgresDB(&SQLXDriver{db: db}), - codeSizeCache: csc, - codeCache: fastcache.New(codeCacheSize), - }, nil + } } // ContractCode satisfies Database, it returns the contract code for a given codehash diff --git a/statedb_test.go b/statedb_test.go index 335522e..bb097b9 100644 --- a/statedb_test.go +++ b/statedb_test.go @@ -3,8 +3,6 @@ package ipld_eth_statedb_test import ( "context" "math/big" - "os" - "strconv" "testing" "github.com/lib/pq" @@ -15,6 +13,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" "github.com/ethereum/go-ethereum/statediff/indexer/ipld" statedb "github.com/cerc-io/ipld-eth-statedb" @@ -86,17 +85,12 @@ var ( ) func TestPGXSuite(t *testing.T) { - testConfig, err := getTestConfig() + testConfig, err := postgres.DefaultConfig.WithEnv() require.NoError(t, err) - pool, err := statedb.NewPGXPool(testCtx, testConfig) + pool, err := postgres.ConnectPGX(testCtx, testConfig) if err != nil { t.Fatal(err) } - driver, err := statedb.NewPGXDriverFromPool(context.Background(), pool) - if err != nil { - t.Fatal(err) - } - database := statedb.NewPostgresDB(driver) t.Cleanup(func() { tx, err := pool.Begin(testCtx) require.NoError(t, err) @@ -112,213 +106,23 @@ func TestPGXSuite(t *testing.T) { } require.NoError(t, tx.Commit(testCtx)) }) - require.NoError(t, insertHeaderCID(database, BlockHash.String(), BlockParentHash.String(), BlockNumber.Uint64())) - require.NoError(t, insertHeaderCID(database, BlockHash2.String(), BlockHash.String(), BlockNumber2)) - require.NoError(t, insertHeaderCID(database, BlockHash3.String(), BlockHash2.String(), BlockNumber3)) - require.NoError(t, insertHeaderCID(database, BlockHash4.String(), BlockHash3.String(), BlockNumber4)) - require.NoError(t, insertHeaderCID(database, NonCanonicalHash4.String(), BlockHash3.String(), BlockNumber4)) - require.NoError(t, insertHeaderCID(database, BlockHash5.String(), BlockHash4.String(), BlockNumber5)) - require.NoError(t, insertHeaderCID(database, NonCanonicalHash5.String(), NonCanonicalHash4.String(), BlockNumber5)) - require.NoError(t, insertHeaderCID(database, BlockHash6.String(), BlockHash5.String(), BlockNumber6)) - require.NoError(t, insertStateCID(database, stateModel{ - BlockNumber: BlockNumber.Uint64(), - BlockHash: BlockHash.String(), - LeafKey: AccountLeafKey.String(), - CID: AccountCID.String(), - Diff: true, - Balance: Account.Balance.Uint64(), - Nonce: Account.Nonce, - CodeHash: AccountCodeHash.String(), - StorageRoot: Account.Root.String(), - Removed: false, - })) - require.NoError(t, insertStateCID(database, stateModel{ - BlockNumber: BlockNumber4, - BlockHash: NonCanonicalHash4.String(), - LeafKey: AccountLeafKey.String(), - CID: AccountCID.String(), - Diff: true, - Balance: big.NewInt(123).Uint64(), - Nonce: Account.Nonce, - CodeHash: AccountCodeHash.String(), - StorageRoot: Account.Root.String(), - Removed: false, - })) - require.NoError(t, insertStateCID(database, stateModel{ - BlockNumber: BlockNumber5, - BlockHash: BlockHash5.String(), - LeafKey: AccountLeafKey.String(), - CID: RemovedNodeStateCID, - Diff: true, - Removed: true, - })) - require.NoError(t, insertStorageCID(database, storageModel{ - BlockNumber: BlockNumber.Uint64(), - BlockHash: BlockHash.String(), - LeafKey: AccountLeafKey.String(), - StorageLeafKey: StorageLeafKey.String(), - StorageCID: StorageCID.String(), - Diff: true, - Value: StoredValueRLP, - Removed: false, - })) - require.NoError(t, insertStorageCID(database, storageModel{ - BlockNumber: BlockNumber2, - BlockHash: BlockHash2.String(), - LeafKey: AccountLeafKey.String(), - StorageLeafKey: StorageLeafKey.String(), - StorageCID: RemovedNodeStorageCID, - Diff: true, - Value: []byte{}, - Removed: true, - })) - require.NoError(t, insertStorageCID(database, storageModel{ - BlockNumber: BlockNumber3, - BlockHash: BlockHash3.String(), - LeafKey: AccountLeafKey.String(), - StorageLeafKey: StorageLeafKey.String(), - StorageCID: StorageCID.String(), - Diff: true, - Value: StoredValueRLP2, - Removed: false, - })) - require.NoError(t, insertStorageCID(database, storageModel{ - BlockNumber: BlockNumber4, - BlockHash: NonCanonicalHash4.String(), - LeafKey: AccountLeafKey.String(), - StorageLeafKey: StorageLeafKey.String(), - StorageCID: StorageCID.String(), - Diff: true, - Value: NonCanonStoredValueRLP, - Removed: false, - })) - require.NoError(t, insertContractCode(database)) - db, err := statedb.NewStateDatabaseWithPgxPool(pool) + driver := statedb.NewPGXDriverFromPool(context.Background(), pool) + database := statedb.NewPostgresDB(driver) + insertSuiteData(t, database) + + db := statedb.NewStateDatabase(database) require.NoError(t, err) - - t.Run("Database", func(t *testing.T) { - size, err := db.ContractCodeSize(AccountCodeHash) - require.NoError(t, err) - require.Equal(t, len(AccountCode), size) - - code, err := db.ContractCode(AccountCodeHash) - require.NoError(t, err) - require.Equal(t, AccountCode, code) - - acct, err := db.StateAccount(AccountLeafKey, BlockHash) - require.NoError(t, err) - require.Equal(t, &Account, acct) - - acct2, err := db.StateAccount(AccountLeafKey, BlockHash2) - require.NoError(t, err) - require.Equal(t, &Account, acct2) - - acct3, err := db.StateAccount(AccountLeafKey, BlockHash3) - require.NoError(t, err) - require.Equal(t, &Account, acct3) - - // check that we don't get the non-canonical account - acct4, err := db.StateAccount(AccountLeafKey, BlockHash4) - require.NoError(t, err) - require.Equal(t, &Account, acct4) - - acct5, err := db.StateAccount(AccountLeafKey, BlockHash5) - require.NoError(t, err) - require.Nil(t, acct5) - - acct6, err := db.StateAccount(AccountLeafKey, BlockHash6) - require.NoError(t, err) - require.Nil(t, acct6) - - val, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash) - require.NoError(t, err) - require.Equal(t, StoredValueRLP, val) - - val2, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash2) - require.NoError(t, err) - require.Nil(t, val2) - - val3, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash3) - require.NoError(t, err) - require.Equal(t, StoredValueRLP2, val3) - - // this checks that we don't get the non-canonical result - val4, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash4) - require.NoError(t, err) - require.Equal(t, StoredValueRLP2, val4) - - // this checks that when the entire account was deleted, we return nil result for storage slot - val5, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash5) - require.NoError(t, err) - require.Nil(t, val5) - - val6, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash6) - require.NoError(t, err) - require.Nil(t, val6) - }) - - t.Run("StateDB", func(t *testing.T) { - sdb, err := statedb.New(BlockHash, db) - require.NoError(t, err) - - checkAccountUnchanged := func() { - require.Equal(t, Account.Balance, sdb.GetBalance(AccountAddress)) - require.Equal(t, Account.Nonce, sdb.GetNonce(AccountAddress)) - require.Equal(t, StoredValue, sdb.GetState(AccountAddress, StorageLeafKey)) - require.Equal(t, AccountCodeHash, sdb.GetCodeHash(AccountAddress)) - require.Equal(t, AccountCode, sdb.GetCode(AccountAddress)) - require.Equal(t, len(AccountCode), sdb.GetCodeSize(AccountAddress)) - } - - require.True(t, sdb.Exist(AccountAddress)) - checkAccountUnchanged() - - id := sdb.Snapshot() - - newStorage := crypto.Keccak256Hash([]byte{5, 4, 3, 2, 1}) - newCode := []byte{1, 3, 3, 7} - - sdb.SetBalance(AccountAddress, big.NewInt(300)) - sdb.AddBalance(AccountAddress, big.NewInt(200)) - sdb.SubBalance(AccountAddress, big.NewInt(100)) - sdb.SetNonce(AccountAddress, 42) - sdb.SetState(AccountAddress, StorageLeafKey, newStorage) - sdb.SetCode(AccountAddress, newCode) - - require.Equal(t, big.NewInt(400), sdb.GetBalance(AccountAddress)) - require.Equal(t, uint64(42), sdb.GetNonce(AccountAddress)) - require.Equal(t, newStorage, sdb.GetState(AccountAddress, StorageLeafKey)) - require.Equal(t, newCode, sdb.GetCode(AccountAddress)) - - sdb.AddSlotToAccessList(AccountAddress, StorageLeafKey) - require.True(t, sdb.AddressInAccessList(AccountAddress)) - hasAddr, hasSlot := sdb.SlotInAccessList(AccountAddress, StorageLeafKey) - require.True(t, hasAddr) - require.True(t, hasSlot) - - sdb.RevertToSnapshot(id) - - checkAccountUnchanged() - require.False(t, sdb.AddressInAccessList(AccountAddress)) - hasAddr, hasSlot = sdb.SlotInAccessList(AccountAddress, StorageLeafKey) - require.False(t, hasAddr) - require.False(t, hasSlot) - }) + testSuite(t, db) } func TestSQLXSuite(t *testing.T) { - testConfig, err := getTestConfig() + testConfig, err := postgres.DefaultConfig.WithEnv() require.NoError(t, err) - pool, err := statedb.NewSQLXPool(testCtx, testConfig) + pool, err := postgres.ConnectSQLX(testCtx, testConfig) if err != nil { t.Fatal(err) } - driver, err := statedb.NewSQLXDriverFromPool(context.Background(), pool) - if err != nil { - t.Fatal(err) - } - database := statedb.NewPostgresDB(driver) t.Cleanup(func() { tx, err := pool.Begin() require.NoError(t, err) @@ -334,6 +138,17 @@ func TestSQLXSuite(t *testing.T) { } require.NoError(t, tx.Commit()) }) + + driver := statedb.NewSQLXDriverFromPool(context.Background(), pool) + database := statedb.NewPostgresDB(driver) + insertSuiteData(t, database) + + db := statedb.NewStateDatabase(database) + require.NoError(t, err) + testSuite(t, db) +} + +func insertSuiteData(t *testing.T, database statedb.Database) { require.NoError(t, insertHeaderCID(database, BlockHash.String(), BlockParentHash.String(), BlockNumber.Uint64())) require.NoError(t, insertHeaderCID(database, BlockHash2.String(), BlockHash.String(), BlockNumber2)) require.NoError(t, insertHeaderCID(database, BlockHash3.String(), BlockHash2.String(), BlockNumber3)) @@ -415,10 +230,9 @@ func TestSQLXSuite(t *testing.T) { Removed: false, })) require.NoError(t, insertContractCode(database)) +} - db, err := statedb.NewStateDatabaseWithSqlxPool(pool) - require.NoError(t, err) - +func testSuite(t *testing.T, db statedb.StateDatabase) { t.Run("Database", func(t *testing.T) { size, err := db.ContractCodeSize(AccountCodeHash) require.NoError(t, err) @@ -648,17 +462,3 @@ func insertContractCode(db statedb.Database) error { _, err := db.Exec(testCtx, sql, BlockNumber.Uint64(), AccountCodeCID.String(), AccountCode) return err } - -func getTestConfig() (conf statedb.Config, err error) { - port, err := strconv.Atoi(os.Getenv("DATABASE_PORT")) - if err != nil { - return - } - return statedb.Config{ - Hostname: os.Getenv("DATABASE_HOSTNAME"), - DatabaseName: os.Getenv("DATABASE_NAME"), - Username: os.Getenv("DATABASE_USER"), - Password: os.Getenv("DATABASE_PASSWORD"), - Port: port, - }, nil -}