Use *sqlx.Tx instead of *sql.Tx

- requires using db.Beginx() instead of db.Begin()
- enables calling tx.Get()
This commit is contained in:
Rob Mulholand 2019-03-06 10:15:32 -06:00
parent 60d7b34471
commit 1414779d52
9 changed files with 21 additions and 23 deletions

View File

@ -18,9 +18,9 @@ package repository
import ( import (
"bytes" "bytes"
"database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"github.com/jmoiron/sqlx"
"github.com/vulcanize/vulcanizedb/libraries/shared/constants" "github.com/vulcanize/vulcanizedb/libraries/shared/constants"
"github.com/vulcanize/vulcanizedb/pkg/core" "github.com/vulcanize/vulcanizedb/pkg/core"
@ -35,7 +35,7 @@ func MarkHeaderChecked(headerID int64, db *postgres.DB, checkedHeadersColumn str
return err return err
} }
func MarkHeaderCheckedInTransaction(headerID int64, tx *sql.Tx, checkedHeadersColumn string) error { func MarkHeaderCheckedInTransaction(headerID int64, tx *sqlx.Tx, checkedHeadersColumn string) error {
_, err := tx.Exec(`INSERT INTO public.checked_headers (header_id, `+checkedHeadersColumn+`) _, err := tx.Exec(`INSERT INTO public.checked_headers (header_id, `+checkedHeadersColumn+`)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT (header_id) DO ON CONFLICT (header_id) DO

View File

@ -97,7 +97,7 @@ var _ = Describe("Repository", func() {
headerRepository := repositories.NewHeaderRepository(db) headerRepository := repositories.NewHeaderRepository(db)
headerID, headerErr := headerRepository.CreateOrUpdateHeader(fakes.FakeHeader) headerID, headerErr := headerRepository.CreateOrUpdateHeader(fakes.FakeHeader)
Expect(headerErr).NotTo(HaveOccurred()) Expect(headerErr).NotTo(HaveOccurred())
tx, txErr := db.Begin() tx, txErr := db.Beginx()
Expect(txErr).NotTo(HaveOccurred()) Expect(txErr).NotTo(HaveOccurred())
err := shared.MarkHeaderCheckedInTransaction(headerID, tx, checkedHeadersColumn) err := shared.MarkHeaderCheckedInTransaction(headerID, tx, checkedHeadersColumn)

View File

@ -17,7 +17,6 @@
package repositories package repositories
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -122,7 +121,7 @@ func (blockRepository BlockRepository) GetBlock(blockNumber int64) (core.Block,
func (blockRepository BlockRepository) insertBlock(block core.Block) (int64, error) { func (blockRepository BlockRepository) insertBlock(block core.Block) (int64, error) {
var blockId int64 var blockId int64
tx, _ := blockRepository.database.BeginTx(context.Background(), nil) tx, _ := blockRepository.database.Beginx()
err := tx.QueryRow( err := tx.QueryRow(
`INSERT INTO blocks `INSERT INTO blocks
(eth_node_id, number, gaslimit, gasused, time, difficulty, hash, nonce, parenthash, size, uncle_hash, is_final, miner, extra_data, reward, uncles_reward, eth_node_fingerprint) (eth_node_id, number, gaslimit, gasused, time, difficulty, hash, nonce, parenthash, size, uncle_hash, is_final, miner, extra_data, reward, uncles_reward, eth_node_fingerprint)
@ -145,7 +144,7 @@ func (blockRepository BlockRepository) insertBlock(block core.Block) (int64, err
return blockId, nil return blockId, nil
} }
func (blockRepository BlockRepository) createTransactions(tx *sql.Tx, blockId int64, transactions []core.Transaction) error { func (blockRepository BlockRepository) createTransactions(tx *sqlx.Tx, blockId int64, transactions []core.Transaction) error {
for _, transaction := range transactions { for _, transaction := range transactions {
err := blockRepository.createTransaction(tx, blockId, transaction) err := blockRepository.createTransaction(tx, blockId, transaction)
if err != nil { if err != nil {
@ -165,7 +164,7 @@ func nullStringToZero(s string) string {
return s return s
} }
func (blockRepository BlockRepository) createTransaction(tx *sql.Tx, blockId int64, transaction core.Transaction) error { func (blockRepository BlockRepository) createTransaction(tx *sqlx.Tx, blockId int64, transaction core.Transaction) error {
_, err := tx.Exec( _, err := tx.Exec(
`INSERT INTO transactions `INSERT INTO transactions
(block_id, hash, nonce, tx_to, tx_from, gaslimit, gasprice, value, input_data) (block_id, hash, nonce, tx_to, tx_from, gaslimit, gasprice, value, input_data)
@ -198,7 +197,7 @@ func hasReceipt(transaction core.Transaction) bool {
return transaction.Receipt.TxHash != "" return transaction.Receipt.TxHash != ""
} }
func (blockRepository BlockRepository) createReceipt(tx *sql.Tx, blockId int64, receipt core.Receipt) (int, error) { func (blockRepository BlockRepository) createReceipt(tx *sqlx.Tx, blockId int64, receipt core.Receipt) (int, error) {
//Not currently persisting log bloom filters //Not currently persisting log bloom filters
var receiptId int var receiptId int
err := tx.QueryRow( err := tx.QueryRow(
@ -224,7 +223,7 @@ func (blockRepository BlockRepository) getBlockHash(block core.Block) (string, b
return retrievedBlockHash, blockExists(retrievedBlockHash) return retrievedBlockHash, blockExists(retrievedBlockHash)
} }
func (blockRepository BlockRepository) createLogs(tx *sql.Tx, logs []core.Log, receiptId int) error { func (blockRepository BlockRepository) createLogs(tx *sqlx.Tx, logs []core.Log, receiptId int) error {
for _, tlog := range logs { for _, tlog := range logs {
_, err := tx.Exec( _, err := tx.Exec(
`INSERT INTO logs (block_number, address, tx_hash, index, topic0, topic1, topic2, topic3, data, receipt_id) `INSERT INTO logs (block_number, address, tx_hash, index, topic0, topic1, topic2, topic3, data, receipt_id)

View File

@ -17,7 +17,6 @@
package repositories package repositories
import ( import (
"context"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"database/sql" "database/sql"
@ -31,7 +30,7 @@ type LogRepository struct {
} }
func (logRepository LogRepository) CreateLogs(lgs []core.Log, receiptId int64) error { func (logRepository LogRepository) CreateLogs(lgs []core.Log, receiptId int64) error {
tx, _ := logRepository.DB.BeginTx(context.Background(), nil) tx, _ := logRepository.DB.Beginx()
for _, tlog := range lgs { for _, tlog := range lgs {
_, err := tx.Exec( _, err := tx.Exec(
`INSERT INTO logs (block_number, address, tx_hash, index, topic0, topic1, topic2, topic3, data, receipt_id) `INSERT INTO logs (block_number, address, tx_hash, index, topic0, topic1, topic2, topic3, data, receipt_id)

View File

@ -17,8 +17,8 @@
package repositories package repositories
import ( import (
"context"
"database/sql" "database/sql"
"github.com/jmoiron/sqlx"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/vulcanize/vulcanizedb/pkg/core" "github.com/vulcanize/vulcanizedb/pkg/core"
@ -31,7 +31,7 @@ type ReceiptRepository struct {
} }
func (receiptRepository ReceiptRepository) CreateReceiptsAndLogs(blockId int64, receipts []core.Receipt) error { func (receiptRepository ReceiptRepository) CreateReceiptsAndLogs(blockId int64, receipts []core.Receipt) error {
tx, err := receiptRepository.DB.BeginTx(context.Background(), nil) tx, err := receiptRepository.DB.Beginx()
if err != nil { if err != nil {
return err return err
} }
@ -53,7 +53,7 @@ func (receiptRepository ReceiptRepository) CreateReceiptsAndLogs(blockId int64,
return nil return nil
} }
func createReceipt(receipt core.Receipt, blockId int64, tx *sql.Tx) (int64, error) { func createReceipt(receipt core.Receipt, blockId int64, tx *sqlx.Tx) (int64, error) {
var receiptId int64 var receiptId int64
err := tx.QueryRow( err := tx.QueryRow(
`INSERT INTO receipts `INSERT INTO receipts
@ -68,7 +68,7 @@ func createReceipt(receipt core.Receipt, blockId int64, tx *sql.Tx) (int64, erro
return receiptId, err return receiptId, err
} }
func createLogs(logs []core.Log, receiptId int64, tx *sql.Tx) error { func createLogs(logs []core.Log, receiptId int64, tx *sqlx.Tx) error {
for _, log := range logs { for _, log := range logs {
_, err := tx.Exec( _, err := tx.Exec(
`INSERT INTO logs (block_number, address, tx_hash, index, topic0, topic1, topic2, topic3, data, receipt_id) `INSERT INTO logs (block_number, address, tx_hash, index, topic0, topic1, topic2, topic3, data, receipt_id)
@ -84,7 +84,7 @@ func createLogs(logs []core.Log, receiptId int64, tx *sql.Tx) error {
} }
func (receiptRepository ReceiptRepository) CreateReceipt(blockId int64, receipt core.Receipt) (int64, error) { func (receiptRepository ReceiptRepository) CreateReceipt(blockId int64, receipt core.Receipt) (int64, error) {
tx, _ := receiptRepository.DB.BeginTx(context.Background(), nil) tx, _ := receiptRepository.DB.Beginx()
var receiptId int64 var receiptId int64
err := tx.QueryRow( err := tx.QueryRow(
`INSERT INTO receipts `INSERT INTO receipts

View File

@ -17,8 +17,8 @@
package repository package repository
import ( import (
"database/sql"
"fmt" "fmt"
"github.com/jmoiron/sqlx"
"github.com/hashicorp/golang-lru" "github.com/hashicorp/golang-lru"
@ -125,7 +125,7 @@ func (r *headerRepository) MarkHeaderCheckedForAll(headerID int64, ids []string)
} }
func (r *headerRepository) MarkHeadersCheckedForAll(headers []core.Header, ids []string) error { func (r *headerRepository) MarkHeadersCheckedForAll(headers []core.Header, ids []string) error {
tx, err := r.db.Begin() tx, err := r.db.Beginx()
if err != nil { if err != nil {
return err return err
} }
@ -250,7 +250,7 @@ func (r *headerRepository) CheckCache(key string) (interface{}, bool) {
return r.columns.Get(key) return r.columns.Get(key)
} }
func MarkHeaderCheckedInTransaction(headerID int64, tx *sql.Tx, eventID string) error { func MarkHeaderCheckedInTransaction(headerID int64, tx *sqlx.Tx, eventID string) error {
_, err := tx.Exec(`INSERT INTO public.checked_headers (header_id, `+eventID+`) _, err := tx.Exec(`INSERT INTO public.checked_headers (header_id, `+eventID+`)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT (header_id) DO ON CONFLICT (header_id) DO

View File

@ -238,7 +238,7 @@ func SetupENSContract(wantedEvents, wantedMethods []string) *contract.Contract {
} }
func TearDown(db *postgres.DB) { func TearDown(db *postgres.DB) {
tx, err := db.Begin() tx, err := db.Beginx()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
_, err = tx.Exec(`DELETE FROM blocks`) _, err = tx.Exec(`DELETE FROM blocks`)

View File

@ -97,7 +97,7 @@ func (r *eventRepository) persistLogs(logs []types.Log, eventInfo types.Event, c
// Creates a custom postgres command to persist logs for the given event (compatible with light synced vDB) // Creates a custom postgres command to persist logs for the given event (compatible with light synced vDB)
func (r *eventRepository) persistLightSyncLogs(logs []types.Log, eventInfo types.Event, contractAddr, contractName string) error { func (r *eventRepository) persistLightSyncLogs(logs []types.Log, eventInfo types.Event, contractAddr, contractName string) error {
tx, err := r.db.Begin() tx, err := r.db.Beginx()
if err != nil { if err != nil {
return err return err
} }
@ -151,7 +151,7 @@ func (r *eventRepository) persistLightSyncLogs(logs []types.Log, eventInfo types
// Creates a custom postgres command to persist logs for the given event (compatible with fully synced vDB) // Creates a custom postgres command to persist logs for the given event (compatible with fully synced vDB)
func (r *eventRepository) persistFullSyncLogs(logs []types.Log, eventInfo types.Event, contractAddr, contractName string) error { func (r *eventRepository) persistFullSyncLogs(logs []types.Log, eventInfo types.Event, contractAddr, contractName string) error {
tx, err := r.db.Begin() tx, err := r.db.Beginx()
if err != nil { if err != nil {
return err return err
} }

View File

@ -77,7 +77,7 @@ func (r *methodRepository) PersistResults(results []types.Result, methodInfo typ
// Creates a custom postgres command to persist logs for the given event // Creates a custom postgres command to persist logs for the given event
func (r *methodRepository) persistResults(results []types.Result, methodInfo types.Method, contractAddr, contractName string) error { func (r *methodRepository) persistResults(results []types.Result, methodInfo types.Method, contractAddr, contractName string) error {
tx, err := r.DB.Begin() tx, err := r.DB.Beginx()
if err != nil { if err != nil {
return err return err
} }