diff --git a/pkg/repositories/in_memory.go b/pkg/repositories/in_memory.go index 99343784..5af378fa 100644 --- a/pkg/repositories/in_memory.go +++ b/pkg/repositories/in_memory.go @@ -24,8 +24,9 @@ func NewInMemory() *InMemory { } } -func (repository *InMemory) CreateBlock(block core.Block) { +func (repository *InMemory) CreateBlock(block core.Block) error { repository.blocks[block.Number] = &block + return nil } func (repository *InMemory) BlockCount() int { diff --git a/pkg/repositories/postgres.go b/pkg/repositories/postgres.go index 5672c99d..9d1ea7c3 100644 --- a/pkg/repositories/postgres.go +++ b/pkg/repositories/postgres.go @@ -3,6 +3,9 @@ package repositories import ( "database/sql" + "context" + + "errors" "github.com/8thlight/vulcanizedb/pkg/core" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" @@ -12,6 +15,10 @@ type Postgres struct { Db *sqlx.DB } +var ( + ErrDBInsertFailed = errors.New("postgres: insert failed") +) + func (repository Postgres) MaxBlockNumber() int64 { var highestBlockNumber int64 repository.Db.Get(&highestBlockNumber, `SELECT MAX(block_number) FROM blocks`) @@ -57,25 +64,41 @@ func (repository Postgres) BlockCount() int { return count } -func (repository Postgres) CreateBlock(block core.Block) { - insertedBlock := repository.Db.QueryRow( +func (repository Postgres) CreateBlock(block core.Block) error { + tx, _ := repository.Db.BeginTx(context.Background(), nil) + var blockId int64 + err := tx.QueryRow( `INSERT INTO blocks (block_number, block_gaslimit, block_gasused, block_time, block_difficulty, block_hash, block_nonce, block_parenthash, block_size, uncle_hash) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id`, - block.Number, block.GasLimit, block.GasUsed, block.Time, block.Difficulty, block.Hash, block.Nonce, block.ParentHash, block.Size, block.UncleHash) - var blockId int64 - insertedBlock.Scan(&blockId) - repository.createTransactions(blockId, block.Transactions) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + RETURNING id `, + block.Number, block.GasLimit, block.GasUsed, block.Time, block.Difficulty, block.Hash, block.Nonce, block.ParentHash, block.Size, block.UncleHash). + Scan(&blockId) + if err != nil { + tx.Rollback() + return ErrDBInsertFailed + } + err = repository.createTransactions(tx, blockId, block.Transactions) + if err != nil { + tx.Rollback() + return ErrDBInsertFailed + } + tx.Commit() + return nil } -func (repository Postgres) createTransactions(blockId int64, transactions []core.Transaction) { +func (repository Postgres) createTransactions(tx *sql.Tx, blockId int64, transactions []core.Transaction) error { for _, transaction := range transactions { - repository.Db.MustExec( + _, err := tx.Exec( `INSERT INTO transactions (block_id, tx_hash, tx_nonce, tx_to, tx_gaslimit, tx_gasprice, tx_value) VALUES ($1, $2, $3, $4, $5, $6, $7)`, blockId, transaction.Hash, transaction.Nonce, transaction.To, transaction.GasLimit, transaction.GasPrice, transaction.Value) + if err != nil { + return err + } } + return nil } func (repository Postgres) loadBlock(blockRows *sql.Rows) core.Block { @@ -107,7 +130,7 @@ func (repository Postgres) loadBlock(blockRows *sql.Rows) core.Block { } } func (repository Postgres) loadTransactions(blockId int64) []core.Transaction { - transactionRows, _ := repository.Db.Query("SELECT tx_hash, tx_nonce, tx_to, tx_gaslimit, tx_gasprice, tx_value FROM transactions") + transactionRows, _ := repository.Db.Query(`SELECT tx_hash, tx_nonce, tx_to, tx_gaslimit, tx_gasprice, tx_value FROM transactions`) var transactions []core.Transaction for transactionRows.Next() { var hash string diff --git a/pkg/repositories/repository.go b/pkg/repositories/repository.go index 15d88f76..1df8d339 100644 --- a/pkg/repositories/repository.go +++ b/pkg/repositories/repository.go @@ -3,7 +3,7 @@ package repositories import "github.com/8thlight/vulcanizedb/pkg/core" type Repository interface { - CreateBlock(block core.Block) + CreateBlock(block core.Block) error BlockCount() int FindBlockByNumber(blockNumber int64) *core.Block MaxBlockNumber() int64 diff --git a/pkg/repositories/repository_test.go b/pkg/repositories/repository_test.go index 9167ad8d..df7ea8e7 100644 --- a/pkg/repositories/repository_test.go +++ b/pkg/repositories/repository_test.go @@ -1,6 +1,9 @@ package repositories_test import ( + "fmt" + "strings" + "github.com/8thlight/vulcanizedb/pkg/config" "github.com/8thlight/vulcanizedb/pkg/core" "github.com/8thlight/vulcanizedb/pkg/repositories" @@ -133,6 +136,7 @@ var _ = Describe("Repositories", func() { Expect(savedTransaction.GasPrice).To(Equal(gasPrice)) Expect(savedTransaction.Value).To(Equal(value)) }) + }) Describe("The missing block numbers", func() { @@ -219,6 +223,48 @@ var _ = Describe("Repositories", func() { Expect(db).ShouldNot(BeNil()) }) + It("does not commit block if block is invalid", func() { + + //badNonce violates db Nonce field length + badNonce := fmt.Sprintf("x %s", strings.Repeat("1", 100)) + badBlock := core.Block{ + Number: 123, + Nonce: badNonce, + Transactions: []core.Transaction{}, + } + pgConfig := config.DbConnectionString(config.NewConfig("private").Database) + db, _ := sqlx.Connect("postgres", pgConfig) + Expect(db).ShouldNot(BeNil()) + repository := repositories.NewPostgres(db) + + err := repository.CreateBlock(badBlock) + savedBlock := repository.FindBlockByNumber(123) + + Expect(err).ToNot(BeNil()) + Expect(savedBlock).To(BeNil()) + }) + + It("does not commit block or transactions if transaction is invalid", func() { + + //badHash violates db To field length + badHash := fmt.Sprintf("x %s", strings.Repeat("1", 100)) + badTransaction := core.Transaction{To: badHash} + pgConfig := config.DbConnectionString(config.NewConfig("private").Database) + block := core.Block{ + Number: 123, + Transactions: []core.Transaction{badTransaction}, + } + db, _ := sqlx.Connect("postgres", pgConfig) + Expect(db).ShouldNot(BeNil()) + repository := repositories.NewPostgres(db) + + err := repository.CreateBlock(block) + savedBlock := repository.FindBlockByNumber(123) + + Expect(err).ToNot(BeNil()) + Expect(savedBlock).To(BeNil()) + }) + AssertRepositoryBehavior(func() repositories.Repository { pgConfig := config.DbConnectionString(config.NewConfig("private").Database) db, _ := sqlx.Connect("postgres", pgConfig)