harmonydb: better API

This commit is contained in:
Andrew Jackson (Ajax) 2023-07-18 14:51:26 -07:00
parent 724bf76146
commit d39e699e1f
3 changed files with 31 additions and 35 deletions

View File

@ -27,7 +27,7 @@ func TestCrud(t *testing.T) {
withSetup(t, func(miner *kit.TestMiner) {
cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB
err := cdb.Exec(ctx, `
_, err := cdb.Exec(ctx, `
INSERT INTO
itest_scratch (some_int, content)
VALUES
@ -65,11 +65,11 @@ func TestTransaction(t *testing.T) {
withSetup(t, func(miner *kit.TestMiner) {
cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB
if err := cdb.Exec(ctx, "INSERT INTO itest_scratch (some_int) VALUES (4), (5), (6)"); err != nil {
if _, err := cdb.Exec(ctx, "INSERT INTO itest_scratch (some_int) VALUES (4), (5), (6)"); err != nil {
t.Fatal("E0", err)
}
err := cdb.BeginTransaction(ctx, func(tx *harmonydb.Transaction) (commit bool) {
if err := tx.Exec(ctx, "INSERT INTO itest_scratch (some_int) VALUES (7), (8), (9)"); err != nil {
_, err := cdb.BeginTransaction(ctx, func(tx *harmonydb.Tx) (commit bool) {
if _, err := tx.Exec("INSERT INTO itest_scratch (some_int) VALUES (7), (8), (9)"); err != nil {
t.Fatal("E1", err)
}
@ -84,7 +84,7 @@ func TestTransaction(t *testing.T) {
// sum2 is from INSIDE the transaction, so the updated value.
var sum2 int
if err := tx.QueryRow(ctx, "SELECT SUM(some_int) FROM itest_scratch").Scan(&sum2); err != nil {
if err := tx.QueryRow("SELECT SUM(some_int) FROM itest_scratch").Scan(&sum2); err != nil {
t.Fatal("E3", err)
}
if sum2 != 4+5+6+7+8+9 {
@ -126,7 +126,7 @@ func TestPartialWalk(t *testing.T) {
withSetup(t, func(miner *kit.TestMiner) {
cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB
if err := cdb.Exec(ctx, `
if _, err := cdb.Exec(ctx, `
INSERT INTO
itest_scratch (content, some_int)
VALUES

View File

@ -204,7 +204,7 @@ var fs embed.FS
func (db *DB) upgrade() error {
// Does the version table exist? if not, make it.
// NOTE: This cannot change except via the next sql file.
err := db.Exec(context.Background(), `CREATE TABLE IF NOT EXISTS base (
_, err := db.Exec(context.Background(), `CREATE TABLE IF NOT EXISTS base (
id SERIAL PRIMARY KEY,
entry CHAR(12),
applied TIMESTAMP DEFAULT current_timestamp
@ -256,7 +256,7 @@ func (db *DB) upgrade() error {
}
// Mark Completed.
err = db.Exec(context.Background(), "INSERT INTO base (entry) VALUES ($1)", name)
_, err = db.Exec(context.Background(), "INSERT INTO base (entry) VALUES ($1)", name)
if err != nil {
db.log("Cannot update base: " + err.Error())
return fmt.Errorf("cannot insert into base: %w", err)

View File

@ -7,15 +7,6 @@ import (
"github.com/jackc/pgx/v5"
)
type Intf interface {
ITestDeleteAll()
Exec(ctx context.Context, sql rawStringOnly, arguments ...any) error
Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error)
QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row
Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error
BeginTransaction(ctx context.Context, f func(t *Transaction) (commit bool)) (retErr error)
}
// rawStringOnly is _intentionally_private_ to force only basic strings in SQL queries.
// In any package, raw strings will satisfy compilation. Ex:
//
@ -27,9 +18,9 @@ type rawStringOnly string
// Exec executes changes (INSERT, DELETE, or UPDATE).
// Note, for CREATE & DROP please keep these permanent and express
// them in the ./sql/ files (next number).
func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) error {
_, err := db.pgx.Exec(ctx, string(sql), arguments...)
return err
func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) (count int, err error) {
res, err := db.pgx.Exec(ctx, string(sql), arguments...)
return int(res.RowsAffected()), err
}
type Qry interface {
@ -101,17 +92,18 @@ func (db *DB) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnl
return pgxscan.Select(ctx, db.pgx, sliceOfStructPtr, string(sql), arguments...)
}
type Transaction struct {
type Tx struct {
pgx.Tx
ctx context.Context
}
// BeginTransaction is how you can access transactions using this library.
// The entire transaction happens in the function passed in.
// The return must be true or a rollback will occur.
func (db *DB) BeginTransaction(ctx context.Context, f func(t *Transaction) (commit bool)) (retErr error) {
func (db *DB) BeginTransaction(ctx context.Context, f func(*Tx) (commit bool)) (didCommit bool, retErr error) {
tx, err := db.pgx.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return err
return false, err
}
var commit bool
defer func() { // Panic clean-up.
@ -119,31 +111,35 @@ func (db *DB) BeginTransaction(ctx context.Context, f func(t *Transaction) (comm
retErr = tx.Rollback(ctx)
}
}()
commit = f(&Transaction{tx})
commit = f(&Tx{tx, ctx})
if commit {
return tx.Commit(ctx)
err := tx.Commit(ctx)
if err != nil {
return false, err
}
return true, nil
}
return nil
return false, nil
}
// Exec in a transaction.
func (t *Transaction) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) error {
_, err := t.Tx.Exec(ctx, string(sql), arguments...)
return err
func (t *Tx) Exec(sql rawStringOnly, arguments ...any) (count int, err error) {
res, err := t.Tx.Exec(t.ctx, string(sql), arguments...)
return int(res.RowsAffected()), err
}
// Query in a transaction.
func (t *Transaction) Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error) {
q, err := t.Tx.Query(ctx, string(sql), arguments...)
func (t *Tx) Query(sql rawStringOnly, arguments ...any) (*Query, error) {
q, err := t.Tx.Query(t.ctx, string(sql), arguments...)
return &Query{q}, err
}
// QueryRow in a transaction.
func (t *Transaction) QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row {
return t.Tx.QueryRow(ctx, string(sql), arguments...)
func (t *Tx) QueryRow(sql rawStringOnly, arguments ...any) Row {
return t.Tx.QueryRow(t.ctx, string(sql), arguments...)
}
// Select in a transaction.
func (t *Transaction) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
return pgxscan.Select(ctx, t.Tx, sliceOfStructPtr, string(sql), arguments...)
func (t *Tx) Select(sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
return pgxscan.Select(t.ctx, t.Tx, sliceOfStructPtr, string(sql), arguments...)
}