diff --git a/indexer/database/sql/lazy_tx.go b/indexer/database/sql/lazy_tx.go index d34d8ae..701b7cb 100644 --- a/indexer/database/sql/lazy_tx.go +++ b/indexer/database/sql/lazy_tx.go @@ -3,6 +3,7 @@ package sql import ( "context" "reflect" + "sync" "time" "github.com/cerc-io/plugeth-statediff/indexer/database/metrics" @@ -15,6 +16,7 @@ const copyFromCheckLimit = 100 type DelayedTx struct { cache []interface{} db Database + sync.RWMutex } type cachedStmt struct { sql string @@ -27,6 +29,8 @@ type copyFrom struct { rows [][]interface{} } +type result int64 + func (cf *copyFrom) appendRows(rows [][]interface{}) { cf.rows = append(cf.rows, rows...) } @@ -44,6 +48,8 @@ func (tx *DelayedTx) QueryRow(ctx context.Context, sql string, args ...interface } func (tx *DelayedTx) findPrevCopyFrom(tableName []string, columnNames []string, limit int) (*copyFrom, int) { + tx.RLock() + defer tx.RUnlock() for pos, count := len(tx.cache)-1, 0; pos >= 0 && count < limit; pos, count = pos-1, count+1 { prevCopy, ok := tx.cache[pos].(*copyFrom) if ok && prevCopy.matches(tableName, columnNames) { @@ -59,6 +65,8 @@ func (tx *DelayedTx) CopyFrom(ctx context.Context, tableName []string, columnNam "current", len(prevCopy.rows), "new", len(rows), "distance", distance) prevCopy.appendRows(rows) } else { + tx.Lock() + defer tx.Unlock() tx.cache = append(tx.cache, ©From{tableName, columnNames, rows}) } @@ -66,8 +74,10 @@ func (tx *DelayedTx) CopyFrom(ctx context.Context, tableName []string, columnNam } func (tx *DelayedTx) Exec(ctx context.Context, sql string, args ...interface{}) (Result, error) { + tx.Lock() + defer tx.Unlock() tx.cache = append(tx.cache, cachedStmt{sql, args}) - return nil, nil + return result(0), nil } func (tx *DelayedTx) Commit(ctx context.Context) error { @@ -85,6 +95,8 @@ func (tx *DelayedTx) Commit(ctx context.Context) error { rollback(ctx, base) } }() + tx.Lock() + defer tx.Unlock() for _, item := range tx.cache { switch item := item.(type) { case *copyFrom: @@ -105,6 +117,13 @@ func (tx *DelayedTx) Commit(ctx context.Context) error { } func (tx *DelayedTx) Rollback(ctx context.Context) error { + tx.Lock() + defer tx.Unlock() tx.cache = nil return nil } + +// RowsAffected satisfies sql.Result +func (r result) RowsAffected() (int64, error) { + return int64(r), nil +}