From 07a7a534da53e0646cdc0e48f05d96173e4a2cdb Mon Sep 17 00:00:00 2001 From: Ian Norden Date: Mon, 14 Sep 2020 08:50:51 -0500 Subject: [PATCH] batch.Replay() method --- postgres/batch.go | 39 +++++++++++++++++++++++++++++++-------- postgres/batch_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/postgres/batch.go b/postgres/batch.go index 8e38292..96b9b32 100644 --- a/postgres/batch.go +++ b/postgres/batch.go @@ -17,22 +17,25 @@ package pgipfsethdb import ( + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" "github.com/jmoiron/sqlx" ) // Batch is the type that satisfies the ethdb.Batch interface for PG-IPFS Ethereum data using a direct Postgres connection type Batch struct { - db *sqlx.DB - tx *sqlx.Tx - valueSize int + db *sqlx.DB + tx *sqlx.Tx + valueSize int + replayCache map[string][]byte } // NewBatch returns a ethdb.Batch interface for PG-IPFS func NewBatch(db *sqlx.DB, tx *sqlx.Tx) ethdb.Batch { b := &Batch{ - db: db, - tx: tx, + db: db, + tx: tx, + replayCache: make(map[string][]byte), } if tx == nil { b.Reset() @@ -55,6 +58,7 @@ func (b *Batch) Put(key []byte, value []byte) (err error) { return err } b.valueSize += len(value) + b.replayCache[common.Bytes2Hex(key)] = value return nil } @@ -62,7 +66,11 @@ func (b *Batch) Put(key []byte, value []byte) (err error) { // Delete removes the key from the key-value data store func (b *Batch) Delete(key []byte) (err error) { _, err = b.tx.Exec(deletePgStr, key) - return err + if err != nil { + return err + } + delete(b.replayCache, common.Bytes2Hex(key)) + return nil } // ValueSize satisfies the ethdb.Batch interface @@ -79,13 +87,27 @@ func (b *Batch) Write() error { if b.tx == nil { return nil } - return b.tx.Commit() + if err := b.tx.Commit(); err != nil { + return err + } + b.replayCache = nil + return nil } // Replay satisfies the ethdb.Batch interface // Replay replays the batch contents func (b *Batch) Replay(w ethdb.KeyValueWriter) error { - return errNotSupported + if b.tx != nil { + b.tx.Rollback() + b.tx = nil + } + for key, value := range b.replayCache { + if err := w.Put(common.Hex2Bytes(key), value); err != nil { + return err + } + } + b.replayCache = nil + return nil } // Reset satisfies the ethdb.Batch interface @@ -97,5 +119,6 @@ func (b *Batch) Reset() { if err != nil { panic(err) } + b.replayCache = make(map[string][]byte) b.valueSize = 0 } diff --git a/postgres/batch_test.go b/postgres/batch_test.go index 5d89168..f941be6 100644 --- a/postgres/batch_test.go +++ b/postgres/batch_test.go @@ -115,4 +115,30 @@ var _ = Describe("Batch", func() { Expect(size).To(Equal(0)) }) }) + + Describe("Replay", func() { + It("returns the size of data in the batch queued for write", func() { + err = batch.Put(testKeccakEthKey, testValue) + Expect(err).ToNot(HaveOccurred()) + err = batch.Put(testKeccakEthKey2, testValue2) + Expect(err).ToNot(HaveOccurred()) + + _, err = database.Get(testKeccakEthKey) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("sql: no rows in result set")) + _, err = database.Get(testKeccakEthKey2) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("sql: no rows in result set")) + + err = batch.Replay(database) + Expect(err).ToNot(HaveOccurred()) + + val, err := database.Get(testKeccakEthKey) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(testValue)) + val2, err := database.Get(testKeccakEthKey2) + Expect(err).ToNot(HaveOccurred()) + Expect(val2).To(Equal(testValue2)) + }) + }) })