propagate db params to unit tests

This commit is contained in:
Roy Crihfield 2023-07-11 18:28:13 +08:00
parent 524fbb13fb
commit 5ab2e31433
9 changed files with 40 additions and 25 deletions

View File

@ -41,7 +41,7 @@ func setupCSVIndexer(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
ind, err = file.NewStateDiffIndexer(mocks.TestConfig, file.CSVTestConfig) ind, err = file.NewStateDiffIndexer(mocks.TestChainConfig, file.CSVTestConfig)
require.NoError(t, err) require.NoError(t, err)
db, err = postgres.SetupSQLXDB() db, err = postgres.SetupSQLXDB()

View File

@ -41,7 +41,7 @@ func setupIndexer(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
ind, err = file.NewStateDiffIndexer(mocks.TestConfig, file.SQLTestConfig) ind, err = file.NewStateDiffIndexer(mocks.TestChainConfig, file.SQLTestConfig)
require.NoError(t, err) require.NoError(t, err)
db, err = postgres.SetupSQLXDB() db, err = postgres.SetupSQLXDB()

View File

@ -28,7 +28,8 @@ import (
) )
func setupLegacyPGXIndexer(t *testing.T) { func setupLegacyPGXIndexer(t *testing.T) {
db, err = postgres.SetupPGXDB(postgres.TestConfig) config, _ := postgres.TestConfig.WithEnv()
db, err = postgres.SetupPGXDB(config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -29,17 +29,27 @@ import (
"github.com/cerc-io/plugeth-statediff/indexer/test" "github.com/cerc-io/plugeth-statediff/indexer/test"
) )
var defaultPgConfig postgres.Config
func init() {
var err error
defaultPgConfig, err = postgres.TestConfig.WithEnv()
if err != nil {
panic(err)
}
}
func setupPGXIndexer(t *testing.T, config postgres.Config) { func setupPGXIndexer(t *testing.T, config postgres.Config) {
db, err = postgres.SetupPGXDB(config) db, err = postgres.SetupPGXDB(config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ind, err = sql.NewStateDiffIndexer(context.Background(), mocks.TestConfig, db) ind, err = sql.NewStateDiffIndexer(context.Background(), mocks.TestChainConfig, db)
require.NoError(t, err) require.NoError(t, err)
} }
func setupPGX(t *testing.T) { func setupPGX(t *testing.T) {
setupPGXWithConfig(t, postgres.TestConfig) setupPGXWithConfig(t, defaultPgConfig)
} }
func setupPGXWithConfig(t *testing.T, config postgres.Config) { func setupPGXWithConfig(t *testing.T, config postgres.Config) {
@ -48,7 +58,7 @@ func setupPGXWithConfig(t *testing.T, config postgres.Config) {
} }
func setupPGXNonCanonical(t *testing.T) { func setupPGXNonCanonical(t *testing.T) {
setupPGXIndexer(t, postgres.TestConfig) setupPGXIndexer(t, defaultPgConfig)
test.SetupTestDataNonCanonical(t, ind) test.SetupTestDataNonCanonical(t, ind)
} }
@ -103,7 +113,7 @@ func TestPGXIndexer(t *testing.T) {
}) })
t.Run("Publish and index with CopyFrom enabled.", func(t *testing.T) { t.Run("Publish and index with CopyFrom enabled.", func(t *testing.T) {
config := postgres.TestConfig config := defaultPgConfig
config.CopyFrom = true config.CopyFrom = true
setupPGXWithConfig(t, config) setupPGXWithConfig(t, config)
@ -169,7 +179,7 @@ func TestPGXIndexerNonCanonical(t *testing.T) {
} }
func TestPGXWatchAddressMethods(t *testing.T) { func TestPGXWatchAddressMethods(t *testing.T) {
setupPGXIndexer(t, postgres.TestConfig) setupPGXIndexer(t, defaultPgConfig)
defer tearDown(t) defer tearDown(t)
defer checkTxClosure(t, 1, 0, 1) defer checkTxClosure(t, 1, 0, 1)

View File

@ -31,8 +31,9 @@ import (
) )
var ( var (
pgConfig, _ = postgres.MakeConfig(postgres.TestConfig) pgConfig, _ = postgres.TestConfig.WithEnv()
ctx = context.Background() pgxConfig, _ = postgres.MakeConfig(pgConfig)
ctx = context.Background()
) )
func expectContainsSubstring(t *testing.T, full string, sub string) { func expectContainsSubstring(t *testing.T, full string, sub string) {
@ -43,9 +44,9 @@ func expectContainsSubstring(t *testing.T, full string, sub string) {
func TestPostgresPGX(t *testing.T) { func TestPostgresPGX(t *testing.T) {
t.Run("connects to the sql", func(t *testing.T) { t.Run("connects to the sql", func(t *testing.T) {
dbPool, err := pgxpool.ConnectConfig(context.Background(), pgConfig) dbPool, err := pgxpool.ConnectConfig(context.Background(), pgxConfig)
if err != nil { if err != nil {
t.Fatalf("failed to connect to db with connection string: %s err: %v", pgConfig.ConnString(), err) t.Fatalf("failed to connect to db with connection string: %s err: %v", pgxConfig.ConnString(), err)
} }
if dbPool == nil { if dbPool == nil {
t.Fatal("DB pool is nil") t.Fatal("DB pool is nil")
@ -61,9 +62,9 @@ func TestPostgresPGX(t *testing.T) {
// sized int, so use string representation of big.Int // sized int, so use string representation of big.Int
// and cast on insert // and cast on insert
dbPool, err := pgxpool.ConnectConfig(context.Background(), pgConfig) dbPool, err := pgxpool.ConnectConfig(context.Background(), pgxConfig)
if err != nil { if err != nil {
t.Fatalf("failed to connect to db with connection string: %s err: %v", pgConfig.ConnString(), err) t.Fatalf("failed to connect to db with connection string: %s err: %v", pgxConfig.ConnString(), err)
} }
defer dbPool.Close() defer dbPool.Close()
@ -111,7 +112,7 @@ func TestPostgresPGX(t *testing.T) {
badHash := fmt.Sprintf("x %s", strings.Repeat("1", 100)) badHash := fmt.Sprintf("x %s", strings.Repeat("1", 100))
badInfo := node.Info{GenesisBlock: badHash, NetworkID: "1", ID: "x123", ClientName: "geth"} badInfo := node.Info{GenesisBlock: badHash, NetworkID: "1", ID: "x123", ClientName: "geth"}
_, err := postgres.NewPGXDriver(ctx, postgres.TestConfig, badInfo) _, err := postgres.NewPGXDriver(ctx, pgConfig, badInfo)
if err == nil { if err == nil {
t.Fatal("Expected an error") t.Fatal("Expected an error")
} }

View File

@ -35,7 +35,7 @@ func TestPostgresSQLX(t *testing.T) {
t.Run("connects to the database", func(t *testing.T) { t.Run("connects to the database", func(t *testing.T) {
var err error var err error
connStr := postgres.TestConfig.DbConnectionString() connStr := pgConfig.DbConnectionString()
sqlxdb, err = sqlx.Connect("postgres", connStr) sqlxdb, err = sqlx.Connect("postgres", connStr)
if err != nil { if err != nil {
@ -58,7 +58,7 @@ func TestPostgresSQLX(t *testing.T) {
// sized int, so use string representation of big.Int // sized int, so use string representation of big.Int
// and cast on insert // and cast on insert
connStr := postgres.TestConfig.DbConnectionString() connStr := pgConfig.DbConnectionString()
db, err := sqlx.Connect("postgres", connStr) db, err := sqlx.Connect("postgres", connStr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -109,7 +109,7 @@ func TestPostgresSQLX(t *testing.T) {
badHash := fmt.Sprintf("x %s", strings.Repeat("1", 100)) badHash := fmt.Sprintf("x %s", strings.Repeat("1", 100))
badInfo := node.Info{GenesisBlock: badHash, NetworkID: "1", ID: "x123", ClientName: "geth"} badInfo := node.Info{GenesisBlock: badHash, NetworkID: "1", ID: "x123", ClientName: "geth"}
_, err := postgres.NewSQLXDriver(ctx, postgres.TestConfig, badInfo) _, err := postgres.NewSQLXDriver(ctx, pgConfig, badInfo)
if err == nil { if err == nil {
t.Fatal("Expected an error") t.Fatal("Expected an error")
} }

View File

@ -25,7 +25,10 @@ import (
// SetupSQLXDB is used to setup a sqlx db for tests // SetupSQLXDB is used to setup a sqlx db for tests
func SetupSQLXDB() (sql.Database, error) { func SetupSQLXDB() (sql.Database, error) {
conf := TestConfig conf, err := TestConfig.WithEnv()
if err != nil {
return nil, err
}
conf.MaxIdle = 0 conf.MaxIdle = 0
driver, err := NewSQLXDriver(context.Background(), conf, node.Info{}) driver, err := NewSQLXDriver(context.Background(), conf, node.Info{})
if err != nil { if err != nil {

View File

@ -34,7 +34,7 @@ func setupSQLXIndexer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ind, err = sql.NewStateDiffIndexer(context.Background(), mocks.TestConfig, db) ind, err = sql.NewStateDiffIndexer(context.Background(), mocks.TestChainConfig, db)
require.NoError(t, err) require.NoError(t, err)
} }

View File

@ -39,8 +39,8 @@ import (
// Test variables // Test variables
var ( var (
// block data // block data
TestConfig = params.MainnetChainConfig TestChainConfig = params.MainnetChainConfig
BlockNumber = TestConfig.LondonBlock BlockNumber = TestChainConfig.LondonBlock
// canonical block at London height // canonical block at London height
// includes 5 transactions: 3 Legacy + 1 EIP-2930 + 1 EIP-1559 // includes 5 transactions: 3 Legacy + 1 EIP-2930 + 1 EIP-1559
@ -55,7 +55,7 @@ var (
BaseFee: big.NewInt(params.InitialBaseFee), BaseFee: big.NewInt(params.InitialBaseFee),
Coinbase: common.HexToAddress("0xaE9BEa628c4Ce503DcFD7E305CaB4e29E7476777"), Coinbase: common.HexToAddress("0xaE9BEa628c4Ce503DcFD7E305CaB4e29E7476777"),
} }
MockTransactions, MockReceipts, SenderAddr = createTransactionsAndReceipts(TestConfig, BlockNumber) MockTransactions, MockReceipts, SenderAddr = createTransactionsAndReceipts(TestChainConfig, BlockNumber)
MockBlock = types.NewBlock(&MockHeader, MockTransactions, nil, MockReceipts, trie.NewEmpty(nil)) MockBlock = types.NewBlock(&MockHeader, MockTransactions, nil, MockReceipts, trie.NewEmpty(nil))
MockHeaderRlp, _ = rlp.EncodeToBytes(MockBlock.Header()) MockHeaderRlp, _ = rlp.EncodeToBytes(MockBlock.Header())
@ -63,7 +63,7 @@ var (
// includes 2nd and 5th transactions from the canonical block // includes 2nd and 5th transactions from the canonical block
MockNonCanonicalHeader = MockHeader MockNonCanonicalHeader = MockHeader
MockNonCanonicalBlockTransactions = types.Transactions{MockTransactions[1], MockTransactions[4]} MockNonCanonicalBlockTransactions = types.Transactions{MockTransactions[1], MockTransactions[4]}
MockNonCanonicalBlockReceipts = createNonCanonicalBlockReceipts(TestConfig, BlockNumber, MockNonCanonicalBlockTransactions) MockNonCanonicalBlockReceipts = createNonCanonicalBlockReceipts(TestChainConfig, BlockNumber, MockNonCanonicalBlockTransactions)
MockNonCanonicalBlock = types.NewBlock(&MockNonCanonicalHeader, MockNonCanonicalBlockTransactions, nil, MockNonCanonicalBlockReceipts, trie.NewEmpty(nil)) MockNonCanonicalBlock = types.NewBlock(&MockNonCanonicalHeader, MockNonCanonicalBlockTransactions, nil, MockNonCanonicalBlockReceipts, trie.NewEmpty(nil))
MockNonCanonicalHeaderRlp, _ = rlp.EncodeToBytes(MockNonCanonicalBlock.Header()) MockNonCanonicalHeaderRlp, _ = rlp.EncodeToBytes(MockNonCanonicalBlock.Header())
@ -82,7 +82,7 @@ var (
Coinbase: common.HexToAddress("0xaE9BEa628c4Ce503DcFD7E305CaB4e29E7476777"), Coinbase: common.HexToAddress("0xaE9BEa628c4Ce503DcFD7E305CaB4e29E7476777"),
} }
MockNonCanonicalBlock2Transactions = types.Transactions{MockTransactions[2], MockTransactions[4]} MockNonCanonicalBlock2Transactions = types.Transactions{MockTransactions[2], MockTransactions[4]}
MockNonCanonicalBlock2Receipts = createNonCanonicalBlockReceipts(TestConfig, Block2Number, MockNonCanonicalBlock2Transactions) MockNonCanonicalBlock2Receipts = createNonCanonicalBlockReceipts(TestChainConfig, Block2Number, MockNonCanonicalBlock2Transactions)
MockNonCanonicalBlock2 = types.NewBlock(&MockNonCanonicalHeader2, MockNonCanonicalBlock2Transactions, nil, MockNonCanonicalBlock2Receipts, trie.NewEmpty(nil)) MockNonCanonicalBlock2 = types.NewBlock(&MockNonCanonicalHeader2, MockNonCanonicalBlock2Transactions, nil, MockNonCanonicalBlock2Receipts, trie.NewEmpty(nil))
MockNonCanonicalHeader2Rlp, _ = rlp.EncodeToBytes(MockNonCanonicalBlock2.Header()) MockNonCanonicalHeader2Rlp, _ = rlp.EncodeToBytes(MockNonCanonicalBlock2.Header())