Refactor redundant DB code

This commit is contained in:
Roy Crihfield 2023-04-11 17:42:11 +08:00
parent ced8041f72
commit 49177ee0ab
5 changed files with 32 additions and 359 deletions

View File

@ -1,90 +0,0 @@
package ipld_eth_statedb
import (
"context"
"fmt"
"time"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
)
type Config struct {
Hostname string
Port int
DatabaseName string
Username string
Password string
ConnTimeout time.Duration
MaxConns int
MinConns int
MaxConnLifetime time.Duration
MaxConnIdleTime time.Duration
MaxIdle int
}
// DbConnectionString constructs and returns the connection string from the config (for sqlx driver)
func (c Config) DbConnectionString() string {
if len(c.Username) > 0 && len(c.Password) > 0 {
return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable",
c.Username, c.Password, c.Hostname, c.Port, c.DatabaseName)
}
if len(c.Username) > 0 && len(c.Password) == 0 {
return fmt.Sprintf("postgresql://%s@%s:%d/%s?sslmode=disable",
c.Username, c.Hostname, c.Port, c.DatabaseName)
}
return fmt.Sprintf("postgresql://%s:%d/%s?sslmode=disable", c.Hostname, c.Port, c.DatabaseName)
}
// NewPGXPool returns a new pgx conn pool
func NewPGXPool(ctx context.Context, config Config) (*pgxpool.Pool, error) {
pgConf, err := makePGXConfig(config)
if err != nil {
return nil, err
}
return pgxpool.ConnectConfig(ctx, pgConf)
}
// NewSQLXPool returns a new sqlx conn pool
func NewSQLXPool(ctx context.Context, config Config) (*sqlx.DB, error) {
db, err := sqlx.ConnectContext(ctx, "postgres", config.DbConnectionString())
if err != nil {
return nil, err
}
return db, nil
}
// makePGXConfig creates a pgxpool.Config from the provided Config
func makePGXConfig(config Config) (*pgxpool.Config, error) {
conf, err := pgxpool.ParseConfig("")
if err != nil {
return nil, err
}
conf.ConnConfig.Config.Host = config.Hostname
conf.ConnConfig.Config.Port = uint16(config.Port)
conf.ConnConfig.Config.Database = config.DatabaseName
conf.ConnConfig.Config.User = config.Username
conf.ConnConfig.Config.Password = config.Password
if config.ConnTimeout != 0 {
conf.ConnConfig.Config.ConnectTimeout = config.ConnTimeout
}
if config.MaxConns != 0 {
conf.MaxConns = int32(config.MaxConns)
}
if config.MinConns != 0 {
conf.MinConns = int32(config.MinConns)
}
if config.MaxConnLifetime != 0 {
conf.MaxConnLifetime = config.MaxConnLifetime
}
if config.MaxConnIdleTime != 0 {
conf.MaxConnIdleTime = config.MaxConnIdleTime
}
return conf, nil
}

13
pgx.go
View File

@ -15,18 +15,9 @@ type PGXDriver struct {
db *pgxpool.Pool db *pgxpool.Pool
} }
// NewPGXDriver returns a new pgx driver for Postgres
func NewPGXDriver(ctx context.Context, config Config) (*PGXDriver, error) {
db, err := NewPGXPool(ctx, config)
if err != nil {
return nil, err
}
return &PGXDriver{ctx: ctx, db: db}, nil
}
// NewPGXDriverFromPool returns a new pgx driver for Postgres // NewPGXDriverFromPool returns a new pgx driver for Postgres
func NewPGXDriverFromPool(ctx context.Context, db *pgxpool.Pool) (*PGXDriver, error) { func NewPGXDriverFromPool(ctx context.Context, db *pgxpool.Pool) *PGXDriver {
return &PGXDriver{ctx: ctx, db: db}, nil return &PGXDriver{ctx: ctx, db: db}
} }
// QueryRow satisfies sql.Database // QueryRow satisfies sql.Database

20
sqlx.go
View File

@ -14,25 +14,9 @@ type SQLXDriver struct {
db *sqlx.DB db *sqlx.DB
} }
// NewSQLXDriver returns a new sqlx driver for Postgres
func NewSQLXDriver(ctx context.Context, config Config) (*SQLXDriver, error) {
db, err := NewSQLXPool(ctx, config)
if err != nil {
return nil, err
}
if config.MaxConns > 0 {
db.SetMaxOpenConns(config.MaxConns)
}
if config.MaxConnLifetime > 0 {
db.SetConnMaxLifetime(config.MaxConnLifetime)
}
db.SetMaxIdleConns(config.MaxIdle)
return &SQLXDriver{ctx: ctx, db: db}, nil
}
// NewSQLXDriverFromPool returns a new sqlx driver for Postgres // NewSQLXDriverFromPool returns a new sqlx driver for Postgres
func NewSQLXDriverFromPool(ctx context.Context, db *sqlx.DB) (*SQLXDriver, error) { func NewSQLXDriverFromPool(ctx context.Context, db *sqlx.DB) *SQLXDriver {
return &SQLXDriver{ctx: ctx, db: db}, nil return &SQLXDriver{ctx: ctx, db: db}
} }
// QueryRow satisfies sql.Database // QueryRow satisfies sql.Database

View File

@ -8,8 +8,6 @@ import (
"github.com/VictoriaMetrics/fastcache" "github.com/VictoriaMetrics/fastcache"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/jmoiron/sqlx"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
@ -48,24 +46,14 @@ type stateDatabase struct {
codeCache *fastcache.Cache codeCache *fastcache.Cache
} }
// NewStateDatabaseWithPgxPool returns a new Database implementation using the provided postgres connection pool // NewStateDatabase returns a new Database implementation using the passed parameters
func NewStateDatabaseWithPgxPool(pgDb *pgxpool.Pool) (*stateDatabase, error) { func NewStateDatabase(db Database) *stateDatabase {
csc, _ := lru.New(codeSizeCacheSize) csc, _ := lru.New(codeSizeCacheSize)
return &stateDatabase{ return &stateDatabase{
db: NewPostgresDB(&PGXDriver{db: pgDb}), db: db,
codeSizeCache: csc, codeSizeCache: csc,
codeCache: fastcache.New(codeCacheSize), codeCache: fastcache.New(codeCacheSize),
}, nil
} }
// NewStateDatabaseWithSqlxPool returns a new Database implementation using the passed parameters
func NewStateDatabaseWithSqlxPool(db *sqlx.DB) (*stateDatabase, error) {
csc, _ := lru.New(codeSizeCacheSize)
return &stateDatabase{
db: NewPostgresDB(&SQLXDriver{db: db}),
codeSizeCache: csc,
codeCache: fastcache.New(codeCacheSize),
}, nil
} }
// ContractCode satisfies Database, it returns the contract code for a given codehash // ContractCode satisfies Database, it returns the contract code for a given codehash

View File

@ -3,8 +3,6 @@ package ipld_eth_statedb_test
import ( import (
"context" "context"
"math/big" "math/big"
"os"
"strconv"
"testing" "testing"
"github.com/lib/pq" "github.com/lib/pq"
@ -15,6 +13,7 @@ import (
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres"
"github.com/ethereum/go-ethereum/statediff/indexer/ipld" "github.com/ethereum/go-ethereum/statediff/indexer/ipld"
statedb "github.com/cerc-io/ipld-eth-statedb" statedb "github.com/cerc-io/ipld-eth-statedb"
@ -86,17 +85,12 @@ var (
) )
func TestPGXSuite(t *testing.T) { func TestPGXSuite(t *testing.T) {
testConfig, err := getTestConfig() testConfig, err := postgres.DefaultConfig.WithEnv()
require.NoError(t, err) require.NoError(t, err)
pool, err := statedb.NewPGXPool(testCtx, testConfig) pool, err := postgres.ConnectPGX(testCtx, testConfig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
driver, err := statedb.NewPGXDriverFromPool(context.Background(), pool)
if err != nil {
t.Fatal(err)
}
database := statedb.NewPostgresDB(driver)
t.Cleanup(func() { t.Cleanup(func() {
tx, err := pool.Begin(testCtx) tx, err := pool.Begin(testCtx)
require.NoError(t, err) require.NoError(t, err)
@ -112,213 +106,23 @@ func TestPGXSuite(t *testing.T) {
} }
require.NoError(t, tx.Commit(testCtx)) require.NoError(t, tx.Commit(testCtx))
}) })
require.NoError(t, insertHeaderCID(database, BlockHash.String(), BlockParentHash.String(), BlockNumber.Uint64()))
require.NoError(t, insertHeaderCID(database, BlockHash2.String(), BlockHash.String(), BlockNumber2))
require.NoError(t, insertHeaderCID(database, BlockHash3.String(), BlockHash2.String(), BlockNumber3))
require.NoError(t, insertHeaderCID(database, BlockHash4.String(), BlockHash3.String(), BlockNumber4))
require.NoError(t, insertHeaderCID(database, NonCanonicalHash4.String(), BlockHash3.String(), BlockNumber4))
require.NoError(t, insertHeaderCID(database, BlockHash5.String(), BlockHash4.String(), BlockNumber5))
require.NoError(t, insertHeaderCID(database, NonCanonicalHash5.String(), NonCanonicalHash4.String(), BlockNumber5))
require.NoError(t, insertHeaderCID(database, BlockHash6.String(), BlockHash5.String(), BlockNumber6))
require.NoError(t, insertStateCID(database, stateModel{
BlockNumber: BlockNumber.Uint64(),
BlockHash: BlockHash.String(),
LeafKey: AccountLeafKey.String(),
CID: AccountCID.String(),
Diff: true,
Balance: Account.Balance.Uint64(),
Nonce: Account.Nonce,
CodeHash: AccountCodeHash.String(),
StorageRoot: Account.Root.String(),
Removed: false,
}))
require.NoError(t, insertStateCID(database, stateModel{
BlockNumber: BlockNumber4,
BlockHash: NonCanonicalHash4.String(),
LeafKey: AccountLeafKey.String(),
CID: AccountCID.String(),
Diff: true,
Balance: big.NewInt(123).Uint64(),
Nonce: Account.Nonce,
CodeHash: AccountCodeHash.String(),
StorageRoot: Account.Root.String(),
Removed: false,
}))
require.NoError(t, insertStateCID(database, stateModel{
BlockNumber: BlockNumber5,
BlockHash: BlockHash5.String(),
LeafKey: AccountLeafKey.String(),
CID: RemovedNodeStateCID,
Diff: true,
Removed: true,
}))
require.NoError(t, insertStorageCID(database, storageModel{
BlockNumber: BlockNumber.Uint64(),
BlockHash: BlockHash.String(),
LeafKey: AccountLeafKey.String(),
StorageLeafKey: StorageLeafKey.String(),
StorageCID: StorageCID.String(),
Diff: true,
Value: StoredValueRLP,
Removed: false,
}))
require.NoError(t, insertStorageCID(database, storageModel{
BlockNumber: BlockNumber2,
BlockHash: BlockHash2.String(),
LeafKey: AccountLeafKey.String(),
StorageLeafKey: StorageLeafKey.String(),
StorageCID: RemovedNodeStorageCID,
Diff: true,
Value: []byte{},
Removed: true,
}))
require.NoError(t, insertStorageCID(database, storageModel{
BlockNumber: BlockNumber3,
BlockHash: BlockHash3.String(),
LeafKey: AccountLeafKey.String(),
StorageLeafKey: StorageLeafKey.String(),
StorageCID: StorageCID.String(),
Diff: true,
Value: StoredValueRLP2,
Removed: false,
}))
require.NoError(t, insertStorageCID(database, storageModel{
BlockNumber: BlockNumber4,
BlockHash: NonCanonicalHash4.String(),
LeafKey: AccountLeafKey.String(),
StorageLeafKey: StorageLeafKey.String(),
StorageCID: StorageCID.String(),
Diff: true,
Value: NonCanonStoredValueRLP,
Removed: false,
}))
require.NoError(t, insertContractCode(database))
db, err := statedb.NewStateDatabaseWithPgxPool(pool) driver := statedb.NewPGXDriverFromPool(context.Background(), pool)
database := statedb.NewPostgresDB(driver)
insertSuiteData(t, database)
db := statedb.NewStateDatabase(database)
require.NoError(t, err) require.NoError(t, err)
testSuite(t, db)
t.Run("Database", func(t *testing.T) {
size, err := db.ContractCodeSize(AccountCodeHash)
require.NoError(t, err)
require.Equal(t, len(AccountCode), size)
code, err := db.ContractCode(AccountCodeHash)
require.NoError(t, err)
require.Equal(t, AccountCode, code)
acct, err := db.StateAccount(AccountLeafKey, BlockHash)
require.NoError(t, err)
require.Equal(t, &Account, acct)
acct2, err := db.StateAccount(AccountLeafKey, BlockHash2)
require.NoError(t, err)
require.Equal(t, &Account, acct2)
acct3, err := db.StateAccount(AccountLeafKey, BlockHash3)
require.NoError(t, err)
require.Equal(t, &Account, acct3)
// check that we don't get the non-canonical account
acct4, err := db.StateAccount(AccountLeafKey, BlockHash4)
require.NoError(t, err)
require.Equal(t, &Account, acct4)
acct5, err := db.StateAccount(AccountLeafKey, BlockHash5)
require.NoError(t, err)
require.Nil(t, acct5)
acct6, err := db.StateAccount(AccountLeafKey, BlockHash6)
require.NoError(t, err)
require.Nil(t, acct6)
val, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash)
require.NoError(t, err)
require.Equal(t, StoredValueRLP, val)
val2, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash2)
require.NoError(t, err)
require.Nil(t, val2)
val3, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash3)
require.NoError(t, err)
require.Equal(t, StoredValueRLP2, val3)
// this checks that we don't get the non-canonical result
val4, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash4)
require.NoError(t, err)
require.Equal(t, StoredValueRLP2, val4)
// this checks that when the entire account was deleted, we return nil result for storage slot
val5, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash5)
require.NoError(t, err)
require.Nil(t, val5)
val6, err := db.StorageValue(AccountLeafKey, StorageLeafKey, BlockHash6)
require.NoError(t, err)
require.Nil(t, val6)
})
t.Run("StateDB", func(t *testing.T) {
sdb, err := statedb.New(BlockHash, db)
require.NoError(t, err)
checkAccountUnchanged := func() {
require.Equal(t, Account.Balance, sdb.GetBalance(AccountAddress))
require.Equal(t, Account.Nonce, sdb.GetNonce(AccountAddress))
require.Equal(t, StoredValue, sdb.GetState(AccountAddress, StorageLeafKey))
require.Equal(t, AccountCodeHash, sdb.GetCodeHash(AccountAddress))
require.Equal(t, AccountCode, sdb.GetCode(AccountAddress))
require.Equal(t, len(AccountCode), sdb.GetCodeSize(AccountAddress))
}
require.True(t, sdb.Exist(AccountAddress))
checkAccountUnchanged()
id := sdb.Snapshot()
newStorage := crypto.Keccak256Hash([]byte{5, 4, 3, 2, 1})
newCode := []byte{1, 3, 3, 7}
sdb.SetBalance(AccountAddress, big.NewInt(300))
sdb.AddBalance(AccountAddress, big.NewInt(200))
sdb.SubBalance(AccountAddress, big.NewInt(100))
sdb.SetNonce(AccountAddress, 42)
sdb.SetState(AccountAddress, StorageLeafKey, newStorage)
sdb.SetCode(AccountAddress, newCode)
require.Equal(t, big.NewInt(400), sdb.GetBalance(AccountAddress))
require.Equal(t, uint64(42), sdb.GetNonce(AccountAddress))
require.Equal(t, newStorage, sdb.GetState(AccountAddress, StorageLeafKey))
require.Equal(t, newCode, sdb.GetCode(AccountAddress))
sdb.AddSlotToAccessList(AccountAddress, StorageLeafKey)
require.True(t, sdb.AddressInAccessList(AccountAddress))
hasAddr, hasSlot := sdb.SlotInAccessList(AccountAddress, StorageLeafKey)
require.True(t, hasAddr)
require.True(t, hasSlot)
sdb.RevertToSnapshot(id)
checkAccountUnchanged()
require.False(t, sdb.AddressInAccessList(AccountAddress))
hasAddr, hasSlot = sdb.SlotInAccessList(AccountAddress, StorageLeafKey)
require.False(t, hasAddr)
require.False(t, hasSlot)
})
} }
func TestSQLXSuite(t *testing.T) { func TestSQLXSuite(t *testing.T) {
testConfig, err := getTestConfig() testConfig, err := postgres.DefaultConfig.WithEnv()
require.NoError(t, err) require.NoError(t, err)
pool, err := statedb.NewSQLXPool(testCtx, testConfig) pool, err := postgres.ConnectSQLX(testCtx, testConfig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
driver, err := statedb.NewSQLXDriverFromPool(context.Background(), pool)
if err != nil {
t.Fatal(err)
}
database := statedb.NewPostgresDB(driver)
t.Cleanup(func() { t.Cleanup(func() {
tx, err := pool.Begin() tx, err := pool.Begin()
require.NoError(t, err) require.NoError(t, err)
@ -334,6 +138,17 @@ func TestSQLXSuite(t *testing.T) {
} }
require.NoError(t, tx.Commit()) require.NoError(t, tx.Commit())
}) })
driver := statedb.NewSQLXDriverFromPool(context.Background(), pool)
database := statedb.NewPostgresDB(driver)
insertSuiteData(t, database)
db := statedb.NewStateDatabase(database)
require.NoError(t, err)
testSuite(t, db)
}
func insertSuiteData(t *testing.T, database statedb.Database) {
require.NoError(t, insertHeaderCID(database, BlockHash.String(), BlockParentHash.String(), BlockNumber.Uint64())) require.NoError(t, insertHeaderCID(database, BlockHash.String(), BlockParentHash.String(), BlockNumber.Uint64()))
require.NoError(t, insertHeaderCID(database, BlockHash2.String(), BlockHash.String(), BlockNumber2)) require.NoError(t, insertHeaderCID(database, BlockHash2.String(), BlockHash.String(), BlockNumber2))
require.NoError(t, insertHeaderCID(database, BlockHash3.String(), BlockHash2.String(), BlockNumber3)) require.NoError(t, insertHeaderCID(database, BlockHash3.String(), BlockHash2.String(), BlockNumber3))
@ -415,10 +230,9 @@ func TestSQLXSuite(t *testing.T) {
Removed: false, Removed: false,
})) }))
require.NoError(t, insertContractCode(database)) require.NoError(t, insertContractCode(database))
}
db, err := statedb.NewStateDatabaseWithSqlxPool(pool) func testSuite(t *testing.T, db statedb.StateDatabase) {
require.NoError(t, err)
t.Run("Database", func(t *testing.T) { t.Run("Database", func(t *testing.T) {
size, err := db.ContractCodeSize(AccountCodeHash) size, err := db.ContractCodeSize(AccountCodeHash)
require.NoError(t, err) require.NoError(t, err)
@ -648,17 +462,3 @@ func insertContractCode(db statedb.Database) error {
_, err := db.Exec(testCtx, sql, BlockNumber.Uint64(), AccountCodeCID.String(), AccountCode) _, err := db.Exec(testCtx, sql, BlockNumber.Uint64(), AccountCodeCID.String(), AccountCode)
return err return err
} }
func getTestConfig() (conf statedb.Config, err error) {
port, err := strconv.Atoi(os.Getenv("DATABASE_PORT"))
if err != nil {
return
}
return statedb.Config{
Hostname: os.Getenv("DATABASE_HOSTNAME"),
DatabaseName: os.Getenv("DATABASE_NAME"),
Username: os.Getenv("DATABASE_USER"),
Password: os.Getenv("DATABASE_PASSWORD"),
Port: port,
}, nil
}