detect unsafe code uses

This commit is contained in:
Andrew Jackson (Ajax) 2023-12-07 16:01:28 -06:00
parent 0e49673c49
commit 1e09e1e966
2 changed files with 29 additions and 22 deletions

View File

@ -10,6 +10,7 @@ import (
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
logging "github.com/ipfs/go-log/v2" logging "github.com/ipfs/go-log/v2"
@ -33,6 +34,8 @@ type DB struct {
cfg *pgxpool.Config cfg *pgxpool.Config
schema string schema string
hostnames []string hostnames []string
BTFPOnce sync.Once
BTFP uintptr
} }
var logger = logging.Logger("harmonydb") var logger = logging.Logger("harmonydb")

View File

@ -9,6 +9,7 @@ import (
"github.com/jackc/pgerrcode" "github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn"
"github.com/samber/lo"
) )
var inTxErr = errors.New("Cannot use a non-transaction func in a transaction") var inTxErr = errors.New("Cannot use a non-transaction func in a transaction")
@ -25,7 +26,7 @@ type rawStringOnly string
// Note, for CREATE & DROP please keep these permanent and express // Note, for CREATE & DROP please keep these permanent and express
// them in the ./sql/ files (next number). // them in the ./sql/ files (next number).
func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) (count int, err error) { func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) (count int, err error) {
if usedInTransaction() { if db.usedInTransaction() {
return 0, inTxErr return 0, inTxErr
} }
res, err := db.pgx.Exec(ctx, string(sql), arguments...) res, err := db.pgx.Exec(ctx, string(sql), arguments...)
@ -61,7 +62,7 @@ type Query struct {
// fmt.Println(id, name) // fmt.Println(id, name)
// } // }
func (db *DB) Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error) { func (db *DB) Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error) {
if usedInTransaction() { if db.usedInTransaction() {
return &Query{}, inTxErr return &Query{}, inTxErr
} }
q, err := db.pgx.Query(ctx, string(sql), arguments...) q, err := db.pgx.Query(ctx, string(sql), arguments...)
@ -76,7 +77,8 @@ type Row interface {
} }
type rowErr struct{} type rowErr struct{}
func (rowErr) Scan(..any) error { return inTxErr }
func (rowErr) Scan(_ ...any) error { return inTxErr }
// QueryRow gets 1 row using column order matching. // QueryRow gets 1 row using column order matching.
// This is a timesaver for the special case of wanting the first row returned only. // This is a timesaver for the special case of wanting the first row returned only.
@ -86,10 +88,10 @@ func (rowErr) Scan(..any) error { return inTxErr }
// var ID = 123 // var ID = 123
// err := db.QueryRow(ctx, "SELECT name, pet FROM users WHERE ID=?", ID).Scan(&name, &pet) // err := db.QueryRow(ctx, "SELECT name, pet FROM users WHERE ID=?", ID).Scan(&name, &pet)
func (db *DB) QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row { func (db *DB) QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row {
if usedInTransaction() { if db.usedInTransaction() {
return rowErr{} return rowErr{}
} }
return db.pgx.QueryRow(ctx, string(sql), arguments...) return db.pgx.QueryRow(ctx, string(sql), arguments...)
} }
/* /*
@ -107,7 +109,7 @@ Ex:
err := db.Select(ctx, &users, "SELECT name, id, tel_no FROM customers WHERE pet=?", pet) err := db.Select(ctx, &users, "SELECT name, id, tel_no FROM customers WHERE pet=?", pet)
*/ */
func (db *DB) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error { func (db *DB) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
if usedInTransaction() { if db.usedInTransaction() {
return inTxErr return inTxErr
} }
return pgxscan.Select(ctx, db.pgx, sliceOfStructPtr, string(sql), arguments...) return pgxscan.Select(ctx, db.pgx, sliceOfStructPtr, string(sql), arguments...)
@ -118,29 +120,31 @@ type Tx struct {
ctx context.Context ctx context.Context
} }
// usedInTransaction is a helper to prevent nesting transactions // usedInTransaction is a helper to prevent nesting transactions
// & non-transaction calls in transactions. In the case of a stack read error, // & non-transaction calls in transactions. It only checks 20 frames.
// it will return false, so only use true for a course of action. // Fast: This memory should all be in CPU Caches.
func usedInTransaction() bool { func (db *DB) usedInTransaction() bool {
ok := true var framePtrs = (&[20]uintptr{})[:]
fn := "" runtime.Callers(3, framePtrs)
for v:=2; ok; v++ { return lo.Contains(framePtrs, db.BTFP)
_,_,fn,ok = runtime.Caller(v)
if strings.Contains(fn, "BeginTransaction") {
return true
}
}
return false
} }
// BeginTransaction is how you can access transactions using this library. // BeginTransaction is how you can access transactions using this library.
// The entire transaction happens in the function passed in. // The entire transaction happens in the function passed in.
// The return must be true or a rollback will occur. // The return must be true or a rollback will occur.
// Be sure to test the error for IsErrSerialization() if you want to retry // Be sure to test the error for IsErrSerialization() if you want to retry
// when there is a DB serialization error. //
// when there is a DB serialization error.
//
//go:noinline
func (db *DB) BeginTransaction(ctx context.Context, f func(*Tx) (commit bool, err error)) (didCommit bool, retErr error) { func (db *DB) BeginTransaction(ctx context.Context, f func(*Tx) (commit bool, err error)) (didCommit bool, retErr error) {
if usedInTransaction() { db.BTFPOnce.Do(func() {
return 0, inTxErr fp := make([]uintptr, 20)
runtime.Callers(1, fp)
db.BTFP = fp[0]
})
if db.usedInTransaction() {
return false, inTxErr
} }
tx, err := db.pgx.BeginTx(ctx, pgx.TxOptions{}) tx, err := db.pgx.BeginTx(ctx, pgx.TxOptions{})
if err != nil { if err != nil {