detect unsafe code uses
This commit is contained in:
parent
0e49673c49
commit
1e09e1e966
@ -10,6 +10,7 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
logging "github.com/ipfs/go-log/v2"
|
logging "github.com/ipfs/go-log/v2"
|
||||||
@ -33,6 +34,8 @@ type DB struct {
|
|||||||
cfg *pgxpool.Config
|
cfg *pgxpool.Config
|
||||||
schema string
|
schema string
|
||||||
hostnames []string
|
hostnames []string
|
||||||
|
BTFPOnce sync.Once
|
||||||
|
BTFP uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
var logger = logging.Logger("harmonydb")
|
var logger = logging.Logger("harmonydb")
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/jackc/pgerrcode"
|
"github.com/jackc/pgerrcode"
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/samber/lo"
|
||||||
)
|
)
|
||||||
|
|
||||||
var inTxErr = errors.New("Cannot use a non-transaction func in a transaction")
|
var inTxErr = errors.New("Cannot use a non-transaction func in a transaction")
|
||||||
@ -25,7 +26,7 @@ type rawStringOnly string
|
|||||||
// Note, for CREATE & DROP please keep these permanent and express
|
// Note, for CREATE & DROP please keep these permanent and express
|
||||||
// them in the ./sql/ files (next number).
|
// them in the ./sql/ files (next number).
|
||||||
func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) (count int, err error) {
|
func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) (count int, err error) {
|
||||||
if usedInTransaction() {
|
if db.usedInTransaction() {
|
||||||
return 0, inTxErr
|
return 0, inTxErr
|
||||||
}
|
}
|
||||||
res, err := db.pgx.Exec(ctx, string(sql), arguments...)
|
res, err := db.pgx.Exec(ctx, string(sql), arguments...)
|
||||||
@ -61,7 +62,7 @@ type Query struct {
|
|||||||
// fmt.Println(id, name)
|
// fmt.Println(id, name)
|
||||||
// }
|
// }
|
||||||
func (db *DB) Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error) {
|
func (db *DB) Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error) {
|
||||||
if usedInTransaction() {
|
if db.usedInTransaction() {
|
||||||
return &Query{}, inTxErr
|
return &Query{}, inTxErr
|
||||||
}
|
}
|
||||||
q, err := db.pgx.Query(ctx, string(sql), arguments...)
|
q, err := db.pgx.Query(ctx, string(sql), arguments...)
|
||||||
@ -76,7 +77,8 @@ type Row interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type rowErr struct{}
|
type rowErr struct{}
|
||||||
func (rowErr) Scan(..any) error { return inTxErr }
|
|
||||||
|
func (rowErr) Scan(_ ...any) error { return inTxErr }
|
||||||
|
|
||||||
// QueryRow gets 1 row using column order matching.
|
// QueryRow gets 1 row using column order matching.
|
||||||
// This is a timesaver for the special case of wanting the first row returned only.
|
// This is a timesaver for the special case of wanting the first row returned only.
|
||||||
@ -86,10 +88,10 @@ func (rowErr) Scan(..any) error { return inTxErr }
|
|||||||
// var ID = 123
|
// var ID = 123
|
||||||
// err := db.QueryRow(ctx, "SELECT name, pet FROM users WHERE ID=?", ID).Scan(&name, &pet)
|
// err := db.QueryRow(ctx, "SELECT name, pet FROM users WHERE ID=?", ID).Scan(&name, &pet)
|
||||||
func (db *DB) QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row {
|
func (db *DB) QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row {
|
||||||
if usedInTransaction() {
|
if db.usedInTransaction() {
|
||||||
return rowErr{}
|
return rowErr{}
|
||||||
}
|
}
|
||||||
return db.pgx.QueryRow(ctx, string(sql), arguments...)
|
return db.pgx.QueryRow(ctx, string(sql), arguments...)
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -107,7 +109,7 @@ Ex:
|
|||||||
err := db.Select(ctx, &users, "SELECT name, id, tel_no FROM customers WHERE pet=?", pet)
|
err := db.Select(ctx, &users, "SELECT name, id, tel_no FROM customers WHERE pet=?", pet)
|
||||||
*/
|
*/
|
||||||
func (db *DB) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
|
func (db *DB) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
|
||||||
if usedInTransaction() {
|
if db.usedInTransaction() {
|
||||||
return inTxErr
|
return inTxErr
|
||||||
}
|
}
|
||||||
return pgxscan.Select(ctx, db.pgx, sliceOfStructPtr, string(sql), arguments...)
|
return pgxscan.Select(ctx, db.pgx, sliceOfStructPtr, string(sql), arguments...)
|
||||||
@ -118,29 +120,31 @@ type Tx struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
// usedInTransaction is a helper to prevent nesting transactions
|
// usedInTransaction is a helper to prevent nesting transactions
|
||||||
// & non-transaction calls in transactions. In the case of a stack read error,
|
// & non-transaction calls in transactions. It only checks 20 frames.
|
||||||
// it will return false, so only use true for a course of action.
|
// Fast: This memory should all be in CPU Caches.
|
||||||
func usedInTransaction() bool {
|
func (db *DB) usedInTransaction() bool {
|
||||||
ok := true
|
var framePtrs = (&[20]uintptr{})[:]
|
||||||
fn := ""
|
runtime.Callers(3, framePtrs)
|
||||||
for v:=2; ok; v++ {
|
return lo.Contains(framePtrs, db.BTFP)
|
||||||
_,_,fn,ok = runtime.Caller(v)
|
|
||||||
if strings.Contains(fn, "BeginTransaction") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeginTransaction is how you can access transactions using this library.
|
// BeginTransaction is how you can access transactions using this library.
|
||||||
// The entire transaction happens in the function passed in.
|
// The entire transaction happens in the function passed in.
|
||||||
// The return must be true or a rollback will occur.
|
// The return must be true or a rollback will occur.
|
||||||
// Be sure to test the error for IsErrSerialization() if you want to retry
|
// Be sure to test the error for IsErrSerialization() if you want to retry
|
||||||
// when there is a DB serialization error.
|
//
|
||||||
|
// when there is a DB serialization error.
|
||||||
|
//
|
||||||
|
//go:noinline
|
||||||
func (db *DB) BeginTransaction(ctx context.Context, f func(*Tx) (commit bool, err error)) (didCommit bool, retErr error) {
|
func (db *DB) BeginTransaction(ctx context.Context, f func(*Tx) (commit bool, err error)) (didCommit bool, retErr error) {
|
||||||
if usedInTransaction() {
|
db.BTFPOnce.Do(func() {
|
||||||
return 0, inTxErr
|
fp := make([]uintptr, 20)
|
||||||
|
runtime.Callers(1, fp)
|
||||||
|
db.BTFP = fp[0]
|
||||||
|
})
|
||||||
|
if db.usedInTransaction() {
|
||||||
|
return false, inTxErr
|
||||||
}
|
}
|
||||||
tx, err := db.pgx.BeginTx(ctx, pgx.TxOptions{})
|
tx, err := db.pgx.BeginTx(ctx, pgx.TxOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user