diff --git a/itests/harmonydb_test.go b/itests/harmonydb_test.go index 9e09691da..b52a2aa8f 100644 --- a/itests/harmonydb_test.go +++ b/itests/harmonydb_test.go @@ -27,7 +27,7 @@ func TestCrud(t *testing.T) { withSetup(t, func(miner *kit.TestMiner) { cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB - err := cdb.Exec(ctx, ` + _, err := cdb.Exec(ctx, ` INSERT INTO itest_scratch (some_int, content) VALUES @@ -65,11 +65,11 @@ func TestTransaction(t *testing.T) { withSetup(t, func(miner *kit.TestMiner) { cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB - if err := cdb.Exec(ctx, "INSERT INTO itest_scratch (some_int) VALUES (4), (5), (6)"); err != nil { + if _, err := cdb.Exec(ctx, "INSERT INTO itest_scratch (some_int) VALUES (4), (5), (6)"); err != nil { t.Fatal("E0", err) } - err := cdb.BeginTransaction(ctx, func(tx *harmonydb.Transaction) (commit bool) { - if err := tx.Exec(ctx, "INSERT INTO itest_scratch (some_int) VALUES (7), (8), (9)"); err != nil { + _, err := cdb.BeginTransaction(ctx, func(tx *harmonydb.Tx) (commit bool) { + if _, err := tx.Exec("INSERT INTO itest_scratch (some_int) VALUES (7), (8), (9)"); err != nil { t.Fatal("E1", err) } @@ -84,7 +84,7 @@ func TestTransaction(t *testing.T) { // sum2 is from INSIDE the transaction, so the updated value. var sum2 int - if err := tx.QueryRow(ctx, "SELECT SUM(some_int) FROM itest_scratch").Scan(&sum2); err != nil { + if err := tx.QueryRow("SELECT SUM(some_int) FROM itest_scratch").Scan(&sum2); err != nil { t.Fatal("E3", err) } if sum2 != 4+5+6+7+8+9 { @@ -126,7 +126,7 @@ func TestPartialWalk(t *testing.T) { withSetup(t, func(miner *kit.TestMiner) { cdb := miner.BaseAPI.(*impl.StorageMinerAPI).HarmonyDB - if err := cdb.Exec(ctx, ` + if _, err := cdb.Exec(ctx, ` INSERT INTO itest_scratch (content, some_int) VALUES diff --git a/lib/harmony/harmonydb/harmonydb.go b/lib/harmony/harmonydb/harmonydb.go index 702a97681..fd31e7a13 100644 --- a/lib/harmony/harmonydb/harmonydb.go +++ b/lib/harmony/harmonydb/harmonydb.go @@ -204,7 +204,7 @@ var fs embed.FS func (db *DB) upgrade() error { // Does the version table exist? if not, make it. // NOTE: This cannot change except via the next sql file. - err := db.Exec(context.Background(), `CREATE TABLE IF NOT EXISTS base ( + _, err := db.Exec(context.Background(), `CREATE TABLE IF NOT EXISTS base ( id SERIAL PRIMARY KEY, entry CHAR(12), applied TIMESTAMP DEFAULT current_timestamp @@ -256,7 +256,7 @@ func (db *DB) upgrade() error { } // Mark Completed. - err = db.Exec(context.Background(), "INSERT INTO base (entry) VALUES ($1)", name) + _, err = db.Exec(context.Background(), "INSERT INTO base (entry) VALUES ($1)", name) if err != nil { db.log("Cannot update base: " + err.Error()) return fmt.Errorf("cannot insert into base: %w", err) diff --git a/lib/harmony/harmonydb/userfuncs.go b/lib/harmony/harmonydb/userfuncs.go index 8eebdd607..4d35fd8ca 100644 --- a/lib/harmony/harmonydb/userfuncs.go +++ b/lib/harmony/harmonydb/userfuncs.go @@ -7,15 +7,6 @@ import ( "github.com/jackc/pgx/v5" ) -type Intf interface { - ITestDeleteAll() - Exec(ctx context.Context, sql rawStringOnly, arguments ...any) error - Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error) - QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row - Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error - BeginTransaction(ctx context.Context, f func(t *Transaction) (commit bool)) (retErr error) -} - // rawStringOnly is _intentionally_private_ to force only basic strings in SQL queries. // In any package, raw strings will satisfy compilation. Ex: // @@ -27,9 +18,9 @@ type rawStringOnly string // Exec executes changes (INSERT, DELETE, or UPDATE). // Note, for CREATE & DROP please keep these permanent and express // them in the ./sql/ files (next number). -func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) error { - _, err := db.pgx.Exec(ctx, string(sql), arguments...) - return err +func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) (count int, err error) { + res, err := db.pgx.Exec(ctx, string(sql), arguments...) + return int(res.RowsAffected()), err } type Qry interface { @@ -101,17 +92,18 @@ func (db *DB) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnl return pgxscan.Select(ctx, db.pgx, sliceOfStructPtr, string(sql), arguments...) } -type Transaction struct { +type Tx struct { pgx.Tx + ctx context.Context } // BeginTransaction is how you can access transactions using this library. // The entire transaction happens in the function passed in. // The return must be true or a rollback will occur. -func (db *DB) BeginTransaction(ctx context.Context, f func(t *Transaction) (commit bool)) (retErr error) { +func (db *DB) BeginTransaction(ctx context.Context, f func(*Tx) (commit bool)) (didCommit bool, retErr error) { tx, err := db.pgx.BeginTx(ctx, pgx.TxOptions{}) if err != nil { - return err + return false, err } var commit bool defer func() { // Panic clean-up. @@ -119,31 +111,35 @@ func (db *DB) BeginTransaction(ctx context.Context, f func(t *Transaction) (comm retErr = tx.Rollback(ctx) } }() - commit = f(&Transaction{tx}) + commit = f(&Tx{tx, ctx}) if commit { - return tx.Commit(ctx) + err := tx.Commit(ctx) + if err != nil { + return false, err + } + return true, nil } - return nil + return false, nil } // Exec in a transaction. -func (t *Transaction) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) error { - _, err := t.Tx.Exec(ctx, string(sql), arguments...) - return err +func (t *Tx) Exec(sql rawStringOnly, arguments ...any) (count int, err error) { + res, err := t.Tx.Exec(t.ctx, string(sql), arguments...) + return int(res.RowsAffected()), err } // Query in a transaction. -func (t *Transaction) Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error) { - q, err := t.Tx.Query(ctx, string(sql), arguments...) +func (t *Tx) Query(sql rawStringOnly, arguments ...any) (*Query, error) { + q, err := t.Tx.Query(t.ctx, string(sql), arguments...) return &Query{q}, err } // QueryRow in a transaction. -func (t *Transaction) QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row { - return t.Tx.QueryRow(ctx, string(sql), arguments...) +func (t *Tx) QueryRow(sql rawStringOnly, arguments ...any) Row { + return t.Tx.QueryRow(t.ctx, string(sql), arguments...) } // Select in a transaction. -func (t *Transaction) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error { - return pgxscan.Select(ctx, t.Tx, sliceOfStructPtr, string(sql), arguments...) +func (t *Tx) Select(sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error { + return pgxscan.Select(t.ctx, t.Tx, sliceOfStructPtr, string(sql), arguments...) }