harmonydb: better API

This commit is contained in:
Andrew Jackson (Ajax) 2023-07-18 14:51:26 -07:00
parent 724bf76146
commit d39e699e1f
3 changed files with 31 additions and 35 deletions

View File

@ -27,7 +27,7 @@ func TestCrud(t *testing.T) {
withSetup(t, func(miner *kit.TestMiner) { withSetup(t, func(miner *kit.TestMiner) {
cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB
err := cdb.Exec(ctx, ` _, err := cdb.Exec(ctx, `
INSERT INTO INSERT INTO
itest_scratch (some_int, content) itest_scratch (some_int, content)
VALUES VALUES
@ -65,11 +65,11 @@ func TestTransaction(t *testing.T) {
withSetup(t, func(miner *kit.TestMiner) { withSetup(t, func(miner *kit.TestMiner) {
cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB
if err := cdb.Exec(ctx, "INSERT INTO itest_scratch (some_int) VALUES (4), (5), (6)"); err != nil { if _, err := cdb.Exec(ctx, "INSERT INTO itest_scratch (some_int) VALUES (4), (5), (6)"); err != nil {
t.Fatal("E0", err) t.Fatal("E0", err)
} }
err := cdb.BeginTransaction(ctx, func(tx *harmonydb.Transaction) (commit bool) { _, err := cdb.BeginTransaction(ctx, func(tx *harmonydb.Tx) (commit bool) {
if err := tx.Exec(ctx, "INSERT INTO itest_scratch (some_int) VALUES (7), (8), (9)"); err != nil { if _, err := tx.Exec("INSERT INTO itest_scratch (some_int) VALUES (7), (8), (9)"); err != nil {
t.Fatal("E1", err) t.Fatal("E1", err)
} }
@ -84,7 +84,7 @@ func TestTransaction(t *testing.T) {
// sum2 is from INSIDE the transaction, so the updated value. // sum2 is from INSIDE the transaction, so the updated value.
var sum2 int var sum2 int
if err := tx.QueryRow(ctx, "SELECT SUM(some_int) FROM itest_scratch").Scan(&sum2); err != nil { if err := tx.QueryRow("SELECT SUM(some_int) FROM itest_scratch").Scan(&sum2); err != nil {
t.Fatal("E3", err) t.Fatal("E3", err)
} }
if sum2 != 4+5+6+7+8+9 { if sum2 != 4+5+6+7+8+9 {
@ -126,7 +126,7 @@ func TestPartialWalk(t *testing.T) {
withSetup(t, func(miner *kit.TestMiner) { withSetup(t, func(miner *kit.TestMiner) {
cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB
if err := cdb.Exec(ctx, ` if _, err := cdb.Exec(ctx, `
INSERT INTO INSERT INTO
itest_scratch (content, some_int) itest_scratch (content, some_int)
VALUES VALUES

View File

@ -204,7 +204,7 @@ var fs embed.FS
func (db *DB) upgrade() error { func (db *DB) upgrade() error {
// Does the version table exist? if not, make it. // Does the version table exist? if not, make it.
// NOTE: This cannot change except via the next sql file. // NOTE: This cannot change except via the next sql file.
err := db.Exec(context.Background(), `CREATE TABLE IF NOT EXISTS base ( _, err := db.Exec(context.Background(), `CREATE TABLE IF NOT EXISTS base (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
entry CHAR(12), entry CHAR(12),
applied TIMESTAMP DEFAULT current_timestamp applied TIMESTAMP DEFAULT current_timestamp
@ -256,7 +256,7 @@ func (db *DB) upgrade() error {
} }
// Mark Completed. // Mark Completed.
err = db.Exec(context.Background(), "INSERT INTO base (entry) VALUES ($1)", name) _, err = db.Exec(context.Background(), "INSERT INTO base (entry) VALUES ($1)", name)
if err != nil { if err != nil {
db.log("Cannot update base: " + err.Error()) db.log("Cannot update base: " + err.Error())
return fmt.Errorf("cannot insert into base: %w", err) return fmt.Errorf("cannot insert into base: %w", err)

View File

@ -7,15 +7,6 @@ import (
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
) )
type Intf interface {
ITestDeleteAll()
Exec(ctx context.Context, sql rawStringOnly, arguments ...any) error
Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error)
QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row
Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error
BeginTransaction(ctx context.Context, f func(t *Transaction) (commit bool)) (retErr error)
}
// rawStringOnly is _intentionally_private_ to force only basic strings in SQL queries. // rawStringOnly is _intentionally_private_ to force only basic strings in SQL queries.
// In any package, raw strings will satisfy compilation. Ex: // In any package, raw strings will satisfy compilation. Ex:
// //
@ -27,9 +18,9 @@ type rawStringOnly string
// Exec executes changes (INSERT, DELETE, or UPDATE). // Exec executes changes (INSERT, DELETE, or UPDATE).
// Note, for CREATE & DROP please keep these permanent and express // Note, for CREATE & DROP please keep these permanent and express
// them in the ./sql/ files (next number). // them in the ./sql/ files (next number).
func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) error { func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) (count int, err error) {
_, err := db.pgx.Exec(ctx, string(sql), arguments...) res, err := db.pgx.Exec(ctx, string(sql), arguments...)
return err return int(res.RowsAffected()), err
} }
type Qry interface { type Qry interface {
@ -101,17 +92,18 @@ func (db *DB) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnl
return pgxscan.Select(ctx, db.pgx, sliceOfStructPtr, string(sql), arguments...) return pgxscan.Select(ctx, db.pgx, sliceOfStructPtr, string(sql), arguments...)
} }
type Transaction struct { type Tx struct {
pgx.Tx pgx.Tx
ctx context.Context
} }
// BeginTransaction is how you can access transactions using this library. // BeginTransaction is how you can access transactions using this library.
// The entire transaction happens in the function passed in. // The entire transaction happens in the function passed in.
// The return must be true or a rollback will occur. // The return must be true or a rollback will occur.
func (db *DB) BeginTransaction(ctx context.Context, f func(t *Transaction) (commit bool)) (retErr error) { func (db *DB) BeginTransaction(ctx context.Context, f func(*Tx) (commit bool)) (didCommit bool, retErr error) {
tx, err := db.pgx.BeginTx(ctx, pgx.TxOptions{}) tx, err := db.pgx.BeginTx(ctx, pgx.TxOptions{})
if err != nil { if err != nil {
return err return false, err
} }
var commit bool var commit bool
defer func() { // Panic clean-up. defer func() { // Panic clean-up.
@ -119,31 +111,35 @@ func (db *DB) BeginTransaction(ctx context.Context, f func(t *Transaction) (comm
retErr = tx.Rollback(ctx) retErr = tx.Rollback(ctx)
} }
}() }()
commit = f(&Transaction{tx}) commit = f(&Tx{tx, ctx})
if commit { if commit {
return tx.Commit(ctx) err := tx.Commit(ctx)
if err != nil {
return false, err
}
return true, nil
} }
return nil return false, nil
} }
// Exec in a transaction. // Exec in a transaction.
func (t *Transaction) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) error { func (t *Tx) Exec(sql rawStringOnly, arguments ...any) (count int, err error) {
_, err := t.Tx.Exec(ctx, string(sql), arguments...) res, err := t.Tx.Exec(t.ctx, string(sql), arguments...)
return err return int(res.RowsAffected()), err
} }
// Query in a transaction. // Query in a transaction.
func (t *Transaction) Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error) { func (t *Tx) Query(sql rawStringOnly, arguments ...any) (*Query, error) {
q, err := t.Tx.Query(ctx, string(sql), arguments...) q, err := t.Tx.Query(t.ctx, string(sql), arguments...)
return &Query{q}, err return &Query{q}, err
} }
// QueryRow in a transaction. // QueryRow in a transaction.
func (t *Transaction) QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row { func (t *Tx) QueryRow(sql rawStringOnly, arguments ...any) Row {
return t.Tx.QueryRow(ctx, string(sql), arguments...) return t.Tx.QueryRow(t.ctx, string(sql), arguments...)
} }
// Select in a transaction. // Select in a transaction.
func (t *Transaction) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error { func (t *Tx) Select(sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
return pgxscan.Select(ctx, t.Tx, sliceOfStructPtr, string(sql), arguments...) return pgxscan.Select(t.ctx, t.Tx, sliceOfStructPtr, string(sql), arguments...)
} }