lotus/lib/harmony/harmonydb/userfuncs.go

277 lines
7.3 KiB
Go
Raw Normal View History

2023-07-14 23:05:49 +00:00
package harmonydb
import (
"context"
2023-10-26 22:19:39 +00:00
"errors"
2024-04-17 19:18:10 +00:00
"fmt"
2023-12-07 21:32:35 +00:00
"runtime"
"time"
2023-07-14 23:05:49 +00:00
2024-04-17 19:18:10 +00:00
"github.com/georgysavva/scany/v2/dbscan"
2023-10-26 22:19:39 +00:00
"github.com/jackc/pgerrcode"
2023-12-07 22:01:28 +00:00
"github.com/samber/lo"
2024-04-17 19:18:10 +00:00
"github.com/yugabyte/pgx/v5"
"github.com/yugabyte/pgx/v5/pgconn"
2023-07-14 23:05:49 +00:00
)
2024-04-17 19:18:10 +00:00
// errTx is the error reported when a DB method finds, via usedInTransaction,
// that it is being called from inside a transaction callback.
var errTx = errors.New("cannot use a non-transaction func in a transaction")

// rawStringOnly is _intentionally_ unexported so that SQL text can only be
// supplied as an untyped string constant (such as a literal), which Go converts
// to this type implicitly. From any package, this compiles:
//
//	db.Exec(ctx, "INSERT INTO version (number) VALUES (1)")
//
// whereas passing a string variable does not, because other packages cannot
// spell the conversion to an unexported type. This keeps query fragments that
// come from input out of the SQL text; values must go through arguments.
type rawStringOnly string
// Exec executes changes (INSERT, DELETE, or UPDATE).
// Note, for CREATE & DROP please keep these permanent and express
// them in the ./sql/ files (next number).
2023-07-18 21:51:26 +00:00
func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) (count int, err error) {
2023-12-07 22:01:28 +00:00
if db.usedInTransaction() {
2023-12-09 17:20:41 +00:00
return 0, errTx
2023-12-07 21:32:35 +00:00
}
2023-07-18 21:51:26 +00:00
res, err := db.pgx.Exec(ctx, string(sql), arguments...)
return int(res.RowsAffected()), err
2023-07-14 23:05:49 +00:00
}
// Qry is the row-cursor interface that Query wraps. The pgx.Rows values
// produced by DB.Query and Tx.Query satisfy it.
type Qry interface {
	// Next advances to the next row; it must be called before the first Scan.
	Next() bool
	// Scan copies the current row's columns, in column order, into dest.
	Scan(dest ...any) error
	// Values returns the current row as a slice of values.
	Values() ([]any, error)
	// Err returns the cursor's error, if any.
	Err() error
	// Close releases the cursor; callers should defer it.
	Close()
}

// Query offers Next/Err/Close/Scan/Values (promoted from the embedded Qry)
// plus StructScan.
type Query struct {
	Qry
}
// Query allows iterating returned values to save memory consumption
// with the downside of needing to `defer q.Close()`. For a simpler interface,
// try Select()
// Next() must be called to advance the row cursor, including the first time:
// Ex:
// q, err := db.Query(ctx, "SELECT id, name FROM users")
// handleError(err)
// defer q.Close()
//
// for q.Next() {
// var id int
// var name string
// handleError(q.Scan(&id, &name))
// fmt.Println(id, name)
// }
func (db *DB) Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error) {
2023-12-07 22:01:28 +00:00
if db.usedInTransaction() {
2023-12-09 17:20:41 +00:00
return &Query{}, errTx
2023-12-07 21:32:35 +00:00
}
2023-07-14 23:05:49 +00:00
q, err := db.pgx.Query(ctx, string(sql), arguments...)
return &Query{q}, err
}
2024-04-17 19:18:10 +00:00
// StructScan allows scanning a single row into a struct.
// This improves efficiency of processing large result sets
// by avoiding the need to allocate a slice of structs.
2023-07-14 23:05:49 +00:00
func (q *Query) StructScan(s any) error {
2024-04-17 19:18:10 +00:00
return dbscan.ScanRow(s, dbscanRows{q.Qry.(pgx.Rows)})
2023-07-14 23:05:49 +00:00
}
type Row interface {
Scan(...any) error
}
2023-12-07 21:32:35 +00:00
type rowErr struct{}
2023-12-07 22:01:28 +00:00
2023-12-09 17:20:41 +00:00
func (rowErr) Scan(_ ...any) error { return errTx }
2023-12-07 21:32:35 +00:00
2023-07-14 23:05:49 +00:00
// QueryRow gets 1 row using column order matching.
// This is a timesaver for the special case of wanting the first row returned only.
// EX:
//
// var name, pet string
// var ID = 123
// err := db.QueryRow(ctx, "SELECT name, pet FROM users WHERE ID=?", ID).Scan(&name, &pet)
func (db *DB) QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row {
2023-12-07 22:01:28 +00:00
if db.usedInTransaction() {
2023-12-07 21:32:35 +00:00
return rowErr{}
}
2023-12-07 22:01:28 +00:00
return db.pgx.QueryRow(ctx, string(sql), arguments...)
2023-07-14 23:05:49 +00:00
}
2024-04-17 19:18:10 +00:00
// dbscanRows wraps pgx.Rows so it can be handed to dbscan.ScanAll/ScanRow.
// It adds an error-returning Close and a Columns method built from the
// rows' field descriptions.
type dbscanRows struct {
	pgx.Rows
}

// Close closes the wrapped rows and always reports a nil error.
func (d dbscanRows) Close() error {
	d.Rows.Close()
	return nil
}

// Columns returns the result column names, in field-description order.
func (d dbscanRows) Columns() ([]string, error) {
	fields := d.Rows.FieldDescriptions()
	names := make([]string, len(fields))
	for i := range fields {
		names[i] = fields[i].Name
	}
	return names, nil
}
2023-07-14 23:05:49 +00:00
/*
Select multiple rows into a slice using name matching
Ex:
type user struct {
Name string
ID int
Number string `db:"tel_no"`
}
var users []user
pet := "cat"
err := db.Select(ctx, &users, "SELECT name, id, tel_no FROM customers WHERE pet=?", pet)
*/
2023-08-22 14:43:50 +00:00
func (db *DB) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
2023-12-07 22:01:28 +00:00
if db.usedInTransaction() {
2023-12-09 17:20:41 +00:00
return errTx
2023-12-07 21:32:35 +00:00
}
2024-04-17 19:18:10 +00:00
rows, err := db.pgx.Query(ctx, string(sql), arguments...)
if err != nil {
return err
}
defer rows.Close()
return dbscan.ScanAll(sliceOfStructPtr, dbscanRows{rows})
2023-07-14 23:05:49 +00:00
}
2023-07-18 21:51:26 +00:00
type Tx struct {
2023-07-14 23:05:49 +00:00
pgx.Tx
2023-07-18 21:51:26 +00:00
ctx context.Context
2023-07-14 23:05:49 +00:00
}
2023-12-07 22:01:28 +00:00
// usedInTransaction is a helper to prevent nesting transactions
// & non-transaction calls in transactions. It only checks 20 frames.
// Fast: This memory should all be in CPU Caches.
func (db *DB) usedInTransaction() bool {
2023-12-11 16:50:49 +00:00
var framePtrs = (&[20]uintptr{})[:] // 20 can be stack-local (no alloc)
framePtrs = framePtrs[:runtime.Callers(3, framePtrs)] // skip past our caller.
2023-12-11 17:31:38 +00:00
return lo.Contains(framePtrs, db.BTFP.Load()) // Unsafe read @ beginTx overlap, but 'return false' is correct there.
2023-12-07 21:32:35 +00:00
}
// TransactionOptions controls how BeginTransaction reacts to serialization
// failures.
type TransactionOptions struct {
	// RetrySerializationError makes BeginTransaction rerun the transaction
	// when it fails with an error matched by IsErrSerialization.
	RetrySerializationError bool
	// InitialSerializationErrorRetryWait is the wait before the first retry;
	// BeginTransaction doubles it after each retry.
	InitialSerializationErrorRetryWait time.Duration
}

// TransactionOption adjusts TransactionOptions; BeginTransaction accepts any
// number of them and applies them in order.
type TransactionOption func(*TransactionOptions)

// OptionRetry turns on automatic retries after serialization errors.
func OptionRetry() TransactionOption {
	enable := func(o *TransactionOptions) { o.RetrySerializationError = true }
	return enable
}

// OptionSerialRetryTime sets the wait before the first serialization-error
// retry (BeginTransaction starts from 10ms when this is not given).
func OptionSerialRetryTime(d time.Duration) TransactionOption {
	setWait := func(o *TransactionOptions) { o.InitialSerializationErrorRetryWait = d }
	return setWait
}
2023-07-14 23:05:49 +00:00
// BeginTransaction is how you can access transactions using this library.
// The entire transaction happens in the function passed in.
// The return must be true or a rollback will occur.
2023-12-07 21:32:35 +00:00
// Be sure to test the error for IsErrSerialization() if you want to retry
2023-12-07 22:01:28 +00:00
//
// when there is a DB serialization error.
//
//go:noinline
func (db *DB) BeginTransaction(ctx context.Context, f func(*Tx) (commit bool, err error), opt ...TransactionOption) (didCommit bool, retErr error) {
2023-12-07 22:01:28 +00:00
db.BTFPOnce.Do(func() {
fp := make([]uintptr, 20)
runtime.Callers(1, fp)
2023-12-11 17:31:38 +00:00
db.BTFP.Store(fp[0])
2023-12-07 22:01:28 +00:00
})
if db.usedInTransaction() {
2023-12-09 17:20:41 +00:00
return false, errTx
2023-12-07 21:32:35 +00:00
}
opts := TransactionOptions{
RetrySerializationError: false,
InitialSerializationErrorRetryWait: 10 * time.Millisecond,
}
for _, o := range opt {
o(&opts)
}
retry:
comm, err := db.transactionInner(ctx, f)
if err != nil && opts.RetrySerializationError && IsErrSerialization(err) {
time.Sleep(opts.InitialSerializationErrorRetryWait)
opts.InitialSerializationErrorRetryWait *= 2
goto retry
}
return comm, err
}
func (db *DB) transactionInner(ctx context.Context, f func(*Tx) (commit bool, err error)) (didCommit bool, retErr error) {
2023-07-14 23:05:49 +00:00
tx, err := db.pgx.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
2023-07-18 21:51:26 +00:00
return false, err
2023-07-14 23:05:49 +00:00
}
var commit bool
defer func() { // Panic clean-up.
if !commit {
2023-11-17 13:43:58 +00:00
if tmp := tx.Rollback(ctx); tmp != nil {
retErr = tmp
}
2023-07-14 23:05:49 +00:00
}
}()
2023-08-10 22:35:35 +00:00
commit, err = f(&Tx{tx, ctx})
if err != nil {
return false, err
}
2023-07-14 23:05:49 +00:00
if commit {
2023-11-17 13:43:58 +00:00
err = tx.Commit(ctx)
2023-07-18 21:51:26 +00:00
if err != nil {
return false, err
}
return true, nil
2023-07-14 23:05:49 +00:00
}
2023-07-18 21:51:26 +00:00
return false, nil
2023-07-14 23:05:49 +00:00
}
// Exec in a transaction.
2023-07-18 21:51:26 +00:00
func (t *Tx) Exec(sql rawStringOnly, arguments ...any) (count int, err error) {
res, err := t.Tx.Exec(t.ctx, string(sql), arguments...)
return int(res.RowsAffected()), err
2023-07-14 23:05:49 +00:00
}
// Query in a transaction.
2023-07-18 21:51:26 +00:00
func (t *Tx) Query(sql rawStringOnly, arguments ...any) (*Query, error) {
q, err := t.Tx.Query(t.ctx, string(sql), arguments...)
2023-07-14 23:05:49 +00:00
return &Query{q}, err
}
// QueryRow in a transaction.
2023-07-18 21:51:26 +00:00
func (t *Tx) QueryRow(sql rawStringOnly, arguments ...any) Row {
return t.Tx.QueryRow(t.ctx, string(sql), arguments...)
2023-07-14 23:05:49 +00:00
}
// Select in a transaction.
2023-07-18 21:51:26 +00:00
func (t *Tx) Select(sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
2024-04-17 19:18:10 +00:00
rows, err := t.Query(sql, arguments...)
if err != nil {
return fmt.Errorf("scany: query multiple result rows: %w", err)
}
defer rows.Close()
return dbscan.ScanAll(sliceOfStructPtr, dbscanRows{rows.Qry.(pgx.Rows)})
}
2023-10-26 22:19:39 +00:00
// IsErrUniqueContraint reports whether err is, or wraps, a *pgconn.PgError
// whose code is pgerrcode.UniqueViolation. (The exported name keeps its
// original spelling so existing callers continue to compile.)
func IsErrUniqueContraint(err error) bool {
	var pgErr *pgconn.PgError
	if !errors.As(err, &pgErr) {
		return false
	}
	return pgErr.Code == pgerrcode.UniqueViolation
}
2023-12-07 21:32:35 +00:00
// IsErrSerialization reports whether err is, or wraps, a *pgconn.PgError
// whose code is pgerrcode.SerializationFailure. BeginTransaction uses it to
// decide whether a failed transaction may be retried.
func IsErrSerialization(err error) bool {
	var pgErr *pgconn.PgError
	if !errors.As(err, &pgErr) {
		return false
	}
	return pgErr.Code == pgerrcode.SerializationFailure
}