guard delayed tx cache

This commit is contained in:
Roy Crihfield 2023-10-03 18:34:13 +08:00
parent e70c55bf86
commit 6ebf8471fb

View File

@ -3,6 +3,7 @@ package sql
import ( import (
"context" "context"
"reflect" "reflect"
"sync"
"time" "time"
"github.com/cerc-io/plugeth-statediff/indexer/database/metrics" "github.com/cerc-io/plugeth-statediff/indexer/database/metrics"
@ -15,6 +16,7 @@ const copyFromCheckLimit = 100
type DelayedTx struct { type DelayedTx struct {
cache []interface{} cache []interface{}
db Database db Database
sync.RWMutex
} }
type cachedStmt struct { type cachedStmt struct {
sql string sql string
@ -27,6 +29,8 @@ type copyFrom struct {
rows [][]interface{} rows [][]interface{}
} }
type result int64
func (cf *copyFrom) appendRows(rows [][]interface{}) { func (cf *copyFrom) appendRows(rows [][]interface{}) {
cf.rows = append(cf.rows, rows...) cf.rows = append(cf.rows, rows...)
} }
@ -44,6 +48,8 @@ func (tx *DelayedTx) QueryRow(ctx context.Context, sql string, args ...interface
} }
func (tx *DelayedTx) findPrevCopyFrom(tableName []string, columnNames []string, limit int) (*copyFrom, int) { func (tx *DelayedTx) findPrevCopyFrom(tableName []string, columnNames []string, limit int) (*copyFrom, int) {
tx.RLock()
defer tx.RUnlock()
for pos, count := len(tx.cache)-1, 0; pos >= 0 && count < limit; pos, count = pos-1, count+1 { for pos, count := len(tx.cache)-1, 0; pos >= 0 && count < limit; pos, count = pos-1, count+1 {
prevCopy, ok := tx.cache[pos].(*copyFrom) prevCopy, ok := tx.cache[pos].(*copyFrom)
if ok && prevCopy.matches(tableName, columnNames) { if ok && prevCopy.matches(tableName, columnNames) {
@ -59,15 +65,19 @@ func (tx *DelayedTx) CopyFrom(ctx context.Context, tableName []string, columnNam
"current", len(prevCopy.rows), "new", len(rows), "distance", distance) "current", len(prevCopy.rows), "new", len(rows), "distance", distance)
prevCopy.appendRows(rows) prevCopy.appendRows(rows)
} else { } else {
tx.Lock()
tx.cache = append(tx.cache, &copyFrom{tableName, columnNames, rows}) tx.cache = append(tx.cache, &copyFrom{tableName, columnNames, rows})
tx.Unlock()
} }
return 0, nil return 0, nil
} }
func (tx *DelayedTx) Exec(ctx context.Context, sql string, args ...interface{}) (Result, error) { func (tx *DelayedTx) Exec(ctx context.Context, sql string, args ...interface{}) (Result, error) {
tx.Lock()
tx.cache = append(tx.cache, cachedStmt{sql, args}) tx.cache = append(tx.cache, cachedStmt{sql, args})
return nil, nil defer tx.Unlock()
return result(0), nil
} }
func (tx *DelayedTx) Commit(ctx context.Context) error { func (tx *DelayedTx) Commit(ctx context.Context) error {
@ -85,6 +95,8 @@ func (tx *DelayedTx) Commit(ctx context.Context) error {
rollback(ctx, base) rollback(ctx, base)
} }
}() }()
tx.Lock()
defer tx.Unlock()
for _, item := range tx.cache { for _, item := range tx.cache {
switch item := item.(type) { switch item := item.(type) {
case *copyFrom: case *copyFrom:
@ -105,6 +117,13 @@ func (tx *DelayedTx) Commit(ctx context.Context) error {
} }
func (tx *DelayedTx) Rollback(ctx context.Context) error { func (tx *DelayedTx) Rollback(ctx context.Context) error {
tx.Lock()
defer tx.Unlock()
tx.cache = nil tx.cache = nil
return nil return nil
} }
// RowsAffected satisfies sql.Result
func (r result) RowsAffected() (int64, error) {
return int64(r), nil
}